开启辅助访问
 找回密码
 立即注册

机器学习系列(三):线性判别分析(fisher究极详解)

_shaman_ 回答数4 浏览数1259

  • 1.原理概述
  • 2.原理详解

    • 2.1 输入
    • 2.2 家用投影
    • 2.3 求解w

  • 3.代码实现
1.原理概述

我们的目的是将高维的数据家用投影到一维直线上并在家用投影的值中取一台阈值进行分类,如下图所示:(绘画水平有限,将就着看)

在上图,很明显左边的家用投影更适合分类,因为两种类别(o和x)在家用投影直线上能轻松地找到一台阈值将其区分开来,而右边的家用投影方向则不适合当前分类。
所以我们需要求解一台适合的家用投影方向w
在理解fisher的时候,我遇到了很多不理解问题,在经过多本书籍的对比之后终于搞懂了,其大致的思路如下:

  • 问题的初衷在于找到一条线将坐标点向该线上家用投影,将这条线的方向设为w,并用该w作为假设带入,最后解出最佳w
  • 按照我们假设的w,将样本点向该直线中家用投影,即w^Tx,求出每一类样本点在家用投影上的均值和方差(或者说是协方差矩阵)
  • 按照类间小,类内大的目标,设立目标函数求解w
值得注意的是,我们求得的w是最终家用投影的平面(在这里为一维直线)方向,而不是感知机或逻辑斯蒂回归中的决策边界,这个问题一度让我以为自个的w求错了!!!
另外,在推导公式中,一定一定记得随时查看当前数据维度,不然极易混淆
2.原理详解

2.1 输入

对于样本点X = (x_1,x_2,...,x_N)^T,其维度为N×p,即每一台样本有p个特征 ,其类别Y = (y_1,y_2,...,y_N)^T,其中y_i\in\{+1,-1\}
我们将样本按其标签分为两类,数量分别为N_1和N_2,即|x_{c1}|=N_1\ ,\ |x_{c2}| = N_2\ ,\ N_1+N_2 = N
在本文中,我们用坐标点家用投影在直线上来直观地表述,即p=2
2.2 家用投影

假设,我是说假设啊,我们目前找到了一台适合的家用投影方向,则样本x在直线z上的家用投影为:z_i = w^Tx_i这个公式的来源:我们假设w的模为1(因为方向重要长度不重要,可以缩放)。在两个向量中,你可以很明显地看出x_i在w上的家用投影为\Delta = x_icos\theta,而x_i和w的点积为x_i·w = |x_i||w|cos\theta = |x_i|cos\theta = \Delta,就是说我们可以用x_i和w的点积表示其x_i在w上的家用投影,写作w^Tx_i的形式。
值得注意的是,在这里家用投影指的是 :以原点为起点,x_i为终点的向量在直线z(方向为w)上的家用投影

然后,所有样本在z上的家用投影的均值为:\bar z = \frac{1}{N}\displaystyle\sum_{i=1}^Nz_i=\frac{1}{N}\displaystyle\sum_{i=1}^Nw^Tx_i
协方差(在这里可以理解为方差)为S_z=\frac{1}{N}\displaystyle\sum_{i=1}^N(z_i-\bar z)(z_i-\bar z)^T=\frac{1}{N}\displaystyle\sum_{i=1}^N(w^Tx_i-\bar z)(w^Tx_i-\bar z)^T
两个类分开写:
类别1:\bar {z_1} = \frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}w^Tx_i ,S_1=\frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}(w^Tx_i-\bar {z_1})(w^Tx_i-\bar {z_1})^T
类别2:\bar {z_2} = \frac{1}{N_2}\displaystyle\sum_{i=1}^{N_2}w^Tx_i ,S_1=\frac{1}{N_2}\displaystyle\sum_{i=1}^{N_2}(w^Tx_i-\bar {z_2})(w^Tx_i-\bar {z_2})^T
注意,这里由于我们是平面坐标点在一维直线上的家用投影,所以\bar z、S_z、\bar {z_1}、\bar{z_2}、S_1、S_2均可以理解为一台数。
还记得我们的目标吗:类内小,类间大
在这里,我们将类间表示为(\bar{z_1}-\bar{z_2})^2,即两个类别的样本分别取均值,其均值之差的平方
将类内表示为S_1+S_2,即两个类别的样本方差之和
由此我们可以很自然地得到一台最大化的目标函数:
J(w) = \frac{(\bar {z_1}-\bar{z_2})^2}{S_1+S_2} \\
极大化这个式子就相当于最小化S_1+S_2和最大化(\bar {z_1}-\bar{z_2})^2,完美符合我们的目标
对于J(w),我们可以进一步的调整:
对于其分子:
(\bar {z_1}-\bar{z_2})^2 = \left[\frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}w^Tx_i -  \frac{1}{N_2}\displaystyle\sum_{i=1}^{N_2}w^Tx_i\right]^2=\left[w^T(\frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}x_i-\frac{1}{N_2}\displaystyle\sum_{i=1}^{N_2}x_i)\right]^2=\left[w^T(\bar{x_{c1}}-\bar{x_{c2}})\right]^2=w^T(\bar{x_{c1}}-\bar{x_{c2}})(\bar{x_{c1}}-\bar{x_{c2}})^Tw \\
对于其分母,我们先取S1进行分析:
S_1 =\frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}(w^Tx_i-\bar {z_1})(w^Tx_i-\bar {z_1})^T = \frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}w^T(x_i-\bar{x_{c1}})(x_i-\bar{x_{c1}})^Tw = w^T\frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}(x_i-\bar{x_{c1}})(x_i-\bar{x_{c1}})^Tw=w^TS_{c1}w \\
上式中我们用S_{c1}表示\frac{1}{N_1}\displaystyle\sum_{i=1}^{N_1}(x_i-\bar{x_{c1}})(x_i-\bar{x_{c1}})^T,很显然,这个式子的意思为类别1中样本的方差
值得注意的是,这里的S_{c1}与之前的S_1有很大的不同!!!!!!S_1表示的是家用投影的方差,维度1×1,在家用投影中可以理解为一台数!一台数!一台数!!!!而S_{c1}表示类别1样本的协方差矩阵,没有家用投影之前的协方差矩阵,维度为p×p,是一台矩阵!一台矩阵!一台矩阵!
这个地方一定要搞清楚,不然你就不知道为啥前面是方差是一台数但是后面方差求出来是一台矩阵了,我当时推了几遍才反应过来!
此外,p×p的协方差矩阵为x_i的每一台特征的标准差与其他特征标准差的乘积
好了,书接上回:
S_2同理,所以其分母等价于:
S_1+S_2 =w^TS_{c1}w + w^TS_{c2}w = w^T(S_{c1}+S_{c2})w \\
则J(w)可以转化为:
J(w) = \frac{w^T(\bar{x_{c1}}-\bar{x_{c2}})(\bar{x_{c1}}-\bar{x_{c2}})^Tw}{ w^T(S_{c1}+S_{c2})w} = \frac{w^TS_bw}{w^TS_ww} \\
我们用S_b表示(\bar{x_{c1}}-\bar{x_{c2}})(\bar{x_{c1}}-\bar{x_{c2}})^T,其被称为类间方差
用S_w表示S_{c1}+S_{c2},其被称为类内方差
这里将J(w)转化的意义在于,将原来假设的家用投影w的方差和均值转化为了各类别样本本身的方差和均值,这样求解出来的w就可以直接用样本本来的值求解了,你甚至可以理解成把虚的变成实的!
2.3 求解w

由上面的式子,可以得到:
J(w) = \frac{w^TS_bw}{w^TS_ww} = w^TS_bw·(w^TS_ww)^{-1}
对w进行求导并令结果为0:
\frac{\partial J(w)}{\partial w} = 2S_bw·(w^TS_ww)^{-1}+w^TS_bw·(-1)(w^TS_ww)^{-2}·2S_ww = 0 \\
两边同时乘上(w^TS_ww)^2,注意这里(w^TS_ww)是一台数,维度计算:1×p\ · \ p×p\ ·\ p×1 = 1×1
可得:
S_bw(w^TS_ww)-w^TSbw·S_ww = 0\\ w^TSbw·S_ww = S_bw(w^TS_ww)  \\
同(w^TS_ww)一样,(w^TS_bw)依然是一台数,所以可得:
S_ww = \frac{w^TS_ww}{w^TS_bw}·S_b·w \\
两边左乘S_w^{-1},可得:
w =\frac{w^TS_ww}{w^TS_bw}·S_w^{-1}·S_b·w \\
由于\large \frac{w^TS_ww}{w^TS_bw}是一台数,并不会影响我们求w的方向,所以我们不妨设它为1(这里可以理解为我们家用投影的直线z可以放缩),即可得:
w = S_w^{-1}·S_b·w \\
我们继续研究,可以发现S_b·w = (\bar{x_{c1}}-\bar{x_{c2}})(\bar{x_{c1}}-\bar{x_{c2}})^T·w,而(\bar{x_{c1}}-\bar{x_{c2}})^T·w的维度计算为1×p · p×1 = 1×1,即实数,与w方向无关。
所以最终我们求得的w为:
w = S_w^{-1}(\bar{x_{c1}}-\bar{x_{c2}}) \\
即两个类协方差矩阵之和的逆矩阵乘上两个类的均值之差
3.代码实现

有了上述结果,我们可以很清楚的得到算法流程:

  • 根据给定的两个类的样本,分别计算其协方差矩阵和样本均值
  • 将协方差矩阵相加得到S_w,将样本均值相减得到(\bar{x_{c1}}-\bar{x_{c2}})
  • 计算w = S_w^{-1}(\bar{x_{c1}}-\bar{x_{c2}})
  • 将新样本根据计算出的w家用投影,并进行分类
以下为python实现代码:
import numpy as np
import matplotlib.pyplot as plt

def fisher(x_1 , y_1 , x_2 ,y_2):

    #类内方差 Sw = S1 + S2
    u_1 = np.mean(x_1 , axis = 0 )
    S1 = (x_1 - u_1).T.dot((x_1-u_1))

    u_2 = np.mean(x_2 , axis = 0 )
    S2 = (x_2 - u_2).T.dot(x_2-u_2)

    Sw = S1 + S2
    # print(Sw)
    w = np.linalg.pinv(Sw).dot(u_1 - u_2) #inv逆矩阵,pinv伪逆矩阵

    return w

def predict(test_data , w , c1 , c2): #根据计算的W进行预测
    u_1 = np.mean(c1, axis=0)
    u_2 = np.mean(c2, axis=0)

    diff_1 = w.dot(u_1.T)
    diff_2 = w.dot(u_2.T)

    diff_cur = w.dot(test_data.T) #根据家用投影距离远近决定是哪一类

    return [1 if abs(diff_cur-diff_1) < abs(diff_cur-diff_2) else -1 for i in range(len(test_data))]

if __name__ == '__mAIn__':
    X_True = [[1,1],[2,2],[0,4],[3,4]] #正例
    X_False = [[3,3],[4,5],[3,4],[5,4]] #负例
    Y_True = [1] * len(X_True)
    Y_False = [-1] * len(X_False)
    w = fisher(X_True,Y_True,X_False,Y_False) #求家用投影直线的方向w

    test_point = np.array([[2,5],[4,0]]) #测试数据
    predict_result = predict(test_point,w,X_True,X_False)
    print('测试点预测类别:'+str(predict_result))
   
    #绘图
    plot_x = np.arange(0,6)
    plot_y = -(w[0]*plot_x)/w[1]
    plt.scatter([x[0] for x in X_True] ,[x[1] for x in X_True] ,c = 'r',label = 'class 1')
    plt.scatter([x[0] for x in X_False] ,[x[1] for x in X_False] ,c = 'b',label = 'class -1')
    plt.scatter([x[0] for x in test_point] ,[x[1] for x in test_point] ,c = 'green',label = 'test_point')
    plt.plot(plot_x,plot_y,c = 'pink')
    plt.legend()
    plt.show()


运行结果:

图中的粉色直线即为所求结果,注意这个是家用投影的直线,而不是决策边界。
使用道具 举报
| 来自北京 用Deepseek满血版问问看
himuraken | 来自北京
我感觉应该是“类间大,类内小”吧
用Deepseek满血版问问看
回复
使用道具 举报
own | 未知
是的,之前笔误写错了,已经更正了,感谢你看的这么仔细,误导大家了很抱歉[捂脸]
回复
使用道具 举报
xmpig | 来自江苏
plot_y = -(w[0]*plot_x)/w[1] 这步不太明白,请问这里为什么有个负号啊?
回复
使用道具 举报
wugouxina | 来自北京
w1x+w2y=0,求y不就得变到等式右边加负号吗
回复
使用道具 举报
快速回复
您需要登录后才可以回帖 登录 | 立即注册

当贝投影