线性判别分类器由向量<math><semantics><mrow><mi>w</mi></mrow><annotation encoding="application/x-tex">w</annotation></semantics></math>w和偏差项<math><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math>b构成。给定样例<math><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math>x,其按照如下规则预测获得类别标记<math><semantics><mrow><mi>y</mi></mrow><annotation encoding="application/x-tex">y</annotation></semantics></math>y,即
<math><semantics><mrow><mi>y</mi><mo>=</mo><mi>s</mi><mi>i</mi><mi>g</mi><mi>n</mi><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mi>x</mi><mo>+</mo><mi>b</mi><mo>)</mo></mrow><annotation encoding="application/x-tex">y=sign(w^Tx+b)</annotation></semantics></math>y=sign(wTx+b)
后面统一使用小写表示列向量,转置表示行向量。
分类过程分为如下两步:
- 首先,使用权重向量w将样本空间投影到直线上去
- 然后,寻找直线上一个点把正样本和负样本分开。
为了寻找最有的线性分类器,即<math><semantics><mrow><mi>w</mi></mrow><annotation encoding="application/x-tex">w</annotation></semantics></math>w和<math><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math>b,一个经典的学习算法是线性判别分析(Fisher’s Linear Discriminant Analysis,LDA)。
简要来说,LDA的基本想法是使不同的样本尽量原理,使同类样本尽量靠近。
这一目标可以通过扩大不同类样本的类中心距离,同时缩小每个类的类内方差来实现。
在一个二分类数据集上,分别记所有正样本的的均值为<math><semantics><mrow><msub><mi>μ</mi><mo>+</mo></msub></mrow><annotation encoding="application/x-tex">\mu_+</annotation></semantics></math>μ+,协方差矩阵为<math><semantics><mrow><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub></mrow><annotation encoding="application/x-tex">\Sigma_+</annotation></semantics></math>Σ+;所有负样本的的均值为<math><semantics><mrow><msub><mi>μ</mi><mo>−</mo></msub></mrow><annotation encoding="application/x-tex">\mu_-</annotation></semantics></math>μ−,协方差矩阵为<math><semantics><mrow><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub></mrow><annotation encoding="application/x-tex">\Sigma_-</annotation></semantics></math>Σ−。
类间距离
投影后的类中心间距离为正类中心的投影点值减去负类投影点值:
<math><semantics><mrow><msub><mi>S</mi><mi>B</mi></msub><mo>(</mo><mi>w</mi><mo>)</mo><mo>=</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">S_B(w)=(w^T\mu_+-w^T\mu_-)^2 </annotation></semantics></math>SB(w)=(wTμ+−wTμ−)2
类内距离
同时,类内方差可写为:
<math><semantics><mrow><msub><mi>S</mi><mi>W</mi></msub><mo>(</mo><mi>w</mi><mo>)</mo><mo>=</mo><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>+</mo></msub><msup><mo>)</mo><mn>2</mn></msup><mo>+</mo><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mn>2</mn></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">S_W(w)=\frac{\sum_x(w^Tx_i-w^T\mu_+)^2+\sum_x(w^Tx_i-w^T\mu_-)^2}{n-1} </annotation></semantics></math>SW(w)=n−1∑x(wTxi−wTμ+)2+∑x(wTxi−wTμ−)2
<math><semantics><mrow><mo>=</mo><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><msup><mo>)</mo><mn>2</mn></msup><mo>+</mo><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><msup><mo>)</mo><mn>2</mn></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{\sum_x(w^T(x_i-\mu_+))^2+\sum_x(w^T(x_i-\mu_-))^2}{n-1} </annotation></semantics></math>=n−1∑x(wT(xi−μ+))2+∑x(wT(xi−μ−))2
<math><semantics><mrow><mo>=</mo><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><msup><mo>)</mo><mi>T</mi></msup><mo>+</mo><msub><mo>∑</mo><mi>x</mi></msub><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><msup><mo>)</mo><mi>T</mi></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{\sum_xw^T(x_i-\mu_+)(w^T(x_i-\mu_+))^T+\sum_xw^T(x_i-\mu_-)(w^T(x_i-\mu_-))^T}{n-1} </annotation></semantics></math>=n−1∑xwT(xi−μ+)(wT(xi−μ+))T+∑xwT(xi−μ−)(wT(xi−μ−))T
<math><semantics><mrow><mo>=</mo><mfrac><mrow><msup><mi>w</mi><mi>T</mi></msup><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi><mo>+</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{w^T\sum_x(x_i-\mu_+)(x_i-\mu_+)^Tw+w^T\sum_x(x_i-\mu_-)(x_i-\mu_-)^Tw}{n-1} </annotation></semantics></math>=n−1wT∑x(xi−μ+)(xi−μ+)Tw+wT∑x(xi−μ−)(xi−μ−)Tw
其中
<math><semantics><mrow><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><msup><mo>)</mo><mi>T</mi></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac><mo>=</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub></mrow><annotation encoding="application/x-tex">\frac{\sum_x(x_i-\mu_+)(x_i-\mu_+)^T}{n-1} = \Sigma_+ </annotation></semantics></math>n−1∑x(xi−μ+)(xi−μ+)T=Σ+
是正类的协方差矩阵,注意
<math><semantics><mrow><mi>x</mi><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">x(x_i-\mu_+) </annotation></semantics></math>x(xi−μ+)
是列向量,所以协方差是一个长宽等于数据维度的方阵。
最后:
<math><semantics><mrow><msub><mi>S</mi><mi>W</mi></msub><mo>(</mo><mi>w</mi><mo>)</mo><mo>=</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mi>w</mi><mo>+</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mi>w</mi></mrow><annotation encoding="application/x-tex">S_W(w)=w^T\Sigma_+w+w^T\Sigma_-w </annotation></semantics></math>SW(w)=wTΣ+w+wTΣ−w
优化目标
线性判别式的总目标就是最大化类间距离,最小化类内方差,类似于聚类:
\mathop{\arg\max}\limits_{w} J(w) = \frac{S_B(w)}{S_W(w)}<math><semantics><mrow><mo>=</mo><mfrac><mrow><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mn>2</mn></msup></mrow><mrow><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mi>w</mi><mo>+</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mi>w</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{(w^T\mu_+-w^T\mu_-)^2}{w^T\Sigma_+w+w^T\Sigma_-w} </annotation></semantics></math>=wTΣ+w+wTΣ−w(wTμ+−wTμ−)2
<math><semantics><mrow><mo>=</mo><mfrac><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><msup><mo>)</mo><mi>T</mi></msup></mrow><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mo>)</mo><mi>w</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">= \frac{w^T(\mu_+-\mu_-)(w^T(\mu_+-\mu_-))^T}{w^T(\Sigma_+-\Sigma_-)w} </annotation></semantics></math>=wT(Σ+−Σ−)wwT(μ+−μ−)(wT(μ+−μ−))T
<math><semantics><mrow><mo>=</mo><mfrac><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi></mrow><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mo>)</mo><mi>w</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">= \frac{w^T(\mu_+-\mu_-)(\mu_+-\mu_-)^Tw}{w^T(\Sigma_+-\Sigma_-)w} </annotation></semantics></math>=wT(Σ+−Σ−)wwT(μ+−μ−)(μ+−μ−)Tw
看到这个形式,我们根据上一篇文档的知识知道这个可以使用广义瑞利商来求极大值。
广义瑞利商
**背景介绍及推导见(瑞利商(Rayleigh quotient)与广义瑞利商(genralized Rayleigh quotient)
**
下面只摘抄一些:
广义瑞利商是指这样的函数𝑅(𝐴,𝐵,𝑥):
R(A,B,x) = \cfrac{X^{H}Ax}{X^{H}Bx}其中𝑥为非零向量,而𝐴,𝐵为𝑛×𝑛的Hermitan矩阵。𝐵为正定矩阵。
令
<math><semantics><mrow><mi>A</mi><mo>=</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup></mrow><annotation encoding="application/x-tex">A=(\mu_+-\mu_-)(\mu_+-\mu_-)^T </annotation></semantics></math>A=(μ+−μ−)(μ+−μ−)T
<math><semantics><mrow><mi>B</mi><mo>=</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub></mrow><annotation encoding="application/x-tex">B= \Sigma_+-\Sigma_- </annotation></semantics></math>B=Σ+−Σ−
\mathop{\arg\max}\limits_{w} J(w) = \frac{w^TAw}{w^TBw}这个就很广义瑞利商了。
至于w的值,使用拉格朗日乘子法可以求解得到:
<math><semantics><mrow><msup><mi>B</mi><mrow><mo>−</mo><mn>1</mn></mrow></msup><mi>A</mi><mi>w</mi><mo>=</mo><mi>λ</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">B^{-1}Aw = \lambda w </annotation></semantics></math>B−1Aw=λw
<math><semantics><mrow><msup><mi>B</mi><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi><mo>=</mo><mi>λ</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">B^{-1}(\mu_+-\mu_-)(\mu_+-\mu_-)^Tw = \lambda w </annotation></semantics></math>B−1(μ+−μ−)(μ+−μ−)Tw=λw
由于
<math><semantics><mrow><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi></mrow><annotation encoding="application/x-tex">(\mu_+-\mu_-)^Tw </annotation></semantics></math>(μ+−μ−)Tw
是行向量乘列向量,所以结果是一个标量,
那我们知道:
<math><semantics><mrow><msup><mi>B</mi><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>∝</mo><mi>λ</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">B^{-1}(\mu_+-\mu_-) \propto \lambda w </annotation></semantics></math>B−1(μ+−μ−)∝λw
<math><semantics><mrow><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><msup><mo>)</mo><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>∝</mo><mi>w</mi></mrow><annotation encoding="application/x-tex">(\Sigma_+-\Sigma_-)^{-1}(\mu_+-\mu_-) \propto w </annotation></semantics></math>(Σ+−Σ−)−1(μ+−μ−)∝w
由于w我们只关注方向而不是长度,所以可以认为:
<math><semantics><mrow><msub><mi>w</mi><mrow><mi>b</mi><mi>e</mi><mi>s</mi><mi>t</mi></mrow></msub><mo>=</mo><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><msup><mo>)</mo><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">w_{best} =(\Sigma_+-\Sigma_-)^{-1}(\mu_+-\mu_-) </annotation></semantics></math>wbest=(Σ+−Σ−)−1(μ+−μ−)