登录
原创

线性判别分析 Linear Discriminant Analysis,LDA

专栏苏州谷歌开发者社区
发布于 2020-11-20 阅读 40
  • 机器学习
原创

线性判别分类器由向量<math><semantics><mrow><mi>w</mi></mrow><annotation encoding="application/x-tex">w</annotation></semantics></math>w和偏差项<math><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math>b构成。给定样例<math><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math>x,其按照如下规则预测获得类别标记<math><semantics><mrow><mi>y</mi></mrow><annotation encoding="application/x-tex">y</annotation></semantics></math>y,即
<math><semantics><mrow><mi>y</mi><mo>=</mo><mi>s</mi><mi>i</mi><mi>g</mi><mi>n</mi><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mi>x</mi><mo>+</mo><mi>b</mi><mo>)</mo></mrow><annotation encoding="application/x-tex">y=sign(w^Tx+b)</annotation></semantics></math>y=sign(wTx+b)
后面统一使用小写表示列向量,转置表示行向量。
分类过程分为如下两步:

  • 首先,使用权重向量w将样本空间投影到直线上去
  • 然后,寻找直线上一个点把正样本和负样本分开。

为了寻找最有的线性分类器,即<math><semantics><mrow><mi>w</mi></mrow><annotation encoding="application/x-tex">w</annotation></semantics></math>w<math><semantics><mrow><mi>b</mi></mrow><annotation encoding="application/x-tex">b</annotation></semantics></math>b,一个经典的学习算法是线性判别分析(Fisher’s Linear Discriminant Analysis,LDA)。

简要来说,LDA的基本想法是使不同的样本尽量原理,使同类样本尽量靠近。

这一目标可以通过扩大不同类样本的类中心距离,同时缩小每个类的类内方差来实现。

在一个二分类数据集上,分别记所有正样本的的均值为<math><semantics><mrow><msub><mi>μ</mi><mo>+</mo></msub></mrow><annotation encoding="application/x-tex">\mu_+</annotation></semantics></math>μ+,协方差矩阵为<math><semantics><mrow><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub></mrow><annotation encoding="application/x-tex">\Sigma_+</annotation></semantics></math>Σ+;所有负样本的的均值为<math><semantics><mrow><msub><mi>μ</mi><mo>−</mo></msub></mrow><annotation encoding="application/x-tex">\mu_-</annotation></semantics></math>μ,协方差矩阵为<math><semantics><mrow><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub></mrow><annotation encoding="application/x-tex">\Sigma_-</annotation></semantics></math>Σ

类间距离

投影后的类中心间距离为正类中心的投影点值减去负类投影点值:

<math><semantics><mrow><msub><mi>S</mi><mi>B</mi></msub><mo>(</mo><mi>w</mi><mo>)</mo><mo>=</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">S_B(w)=(w^T\mu_+-w^T\mu_-)^2 </annotation></semantics></math>SB(w)=(wTμ+wTμ)2

类内距离

同时,类内方差可写为:

<math><semantics><mrow><msub><mi>S</mi><mi>W</mi></msub><mo>(</mo><mi>w</mi><mo>)</mo><mo>=</mo><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>+</mo></msub><msup><mo>)</mo><mn>2</mn></msup><mo>+</mo><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mn>2</mn></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">S_W(w)=\frac{\sum_x(w^Tx_i-w^T\mu_+)^2+\sum_x(w^Tx_i-w^T\mu_-)^2}{n-1} </annotation></semantics></math>SW(w)=n1x(wTxiwTμ+)2+x(wTxiwTμ)2

<math><semantics><mrow><mo>=</mo><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><msup><mo>)</mo><mn>2</mn></msup><mo>+</mo><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><msup><mo>)</mo><mn>2</mn></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{\sum_x(w^T(x_i-\mu_+))^2+\sum_x(w^T(x_i-\mu_-))^2}{n-1} </annotation></semantics></math>=n1x(wT(xiμ+))2+x(wT(xiμ))2

<math><semantics><mrow><mo>=</mo><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><msup><mo>)</mo><mi>T</mi></msup><mo>+</mo><msub><mo>∑</mo><mi>x</mi></msub><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><msup><mo>)</mo><mi>T</mi></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{\sum_xw^T(x_i-\mu_+)(w^T(x_i-\mu_+))^T+\sum_xw^T(x_i-\mu_-)(w^T(x_i-\mu_-))^T}{n-1} </annotation></semantics></math>=n1xwT(xiμ+)(wT(xiμ+))T+xwT(xiμ)(wT(xiμ))T

<math><semantics><mrow><mo>=</mo><mfrac><mrow><msup><mi>w</mi><mi>T</mi></msup><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi><mo>+</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{w^T\sum_x(x_i-\mu_+)(x_i-\mu_+)^Tw+w^T\sum_x(x_i-\mu_-)(x_i-\mu_-)^Tw}{n-1} </annotation></semantics></math>=n1wTx(xiμ+)(xiμ+)Tw+wTx(xiμ)(xiμ)Tw

其中

<math><semantics><mrow><mfrac><mrow><msub><mo>∑</mo><mi>x</mi></msub><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><msup><mo>)</mo><mi>T</mi></msup></mrow><mrow><mi>n</mi><mo>−</mo><mn>1</mn></mrow></mfrac><mo>=</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub></mrow><annotation encoding="application/x-tex">\frac{\sum_x(x_i-\mu_+)(x_i-\mu_+)^T}{n-1} = \Sigma_+ </annotation></semantics></math>n1x(xiμ+)(xiμ+)T=Σ+

是正类的协方差矩阵,注意

<math><semantics><mrow><mi>x</mi><mo>(</mo><msub><mi>x</mi><mi>i</mi></msub><mo>−</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">x(x_i-\mu_+) </annotation></semantics></math>x(xiμ+)

是列向量,所以协方差是一个长宽等于数据维度的方阵。

最后:

<math><semantics><mrow><msub><mi>S</mi><mi>W</mi></msub><mo>(</mo><mi>w</mi><mo>)</mo><mo>=</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mi>w</mi><mo>+</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mi>w</mi></mrow><annotation encoding="application/x-tex">S_W(w)=w^T\Sigma_+w+w^T\Sigma_-w </annotation></semantics></math>SW(w)=wTΣ+w+wTΣw

优化目标

线性判别式的总目标就是最大化类间距离,最小化类内方差,类似于聚类:

\mathop{\arg\max}\limits_{w} J(w) = \frac{S_B(w)}{S_W(w)}

<math><semantics><mrow><mo>=</mo><mfrac><mrow><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mn>2</mn></msup></mrow><mrow><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mi>w</mi><mo>+</mo><msup><mi>w</mi><mi>T</mi></msup><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mi>w</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">=\frac{(w^T\mu_+-w^T\mu_-)^2}{w^T\Sigma_+w+w^T\Sigma_-w} </annotation></semantics></math>=wTΣ+w+wTΣw(wTμ+wTμ)2

<math><semantics><mrow><mo>=</mo><mfrac><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><msup><mo>)</mo><mi>T</mi></msup></mrow><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mo>)</mo><mi>w</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">= \frac{w^T(\mu_+-\mu_-)(w^T(\mu_+-\mu_-))^T}{w^T(\Sigma_+-\Sigma_-)w} </annotation></semantics></math>=wT(Σ+Σ)wwT(μ+μ)(wT(μ+μ))T

<math><semantics><mrow><mo>=</mo><mfrac><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi></mrow><mrow><msup><mi>w</mi><mi>T</mi></msup><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><mo>)</mo><mi>w</mi></mrow></mfrac></mrow><annotation encoding="application/x-tex">= \frac{w^T(\mu_+-\mu_-)(\mu_+-\mu_-)^Tw}{w^T(\Sigma_+-\Sigma_-)w} </annotation></semantics></math>=wT(Σ+Σ)wwT(μ+μ)(μ+μ)Tw

看到这个形式,我们根据上一篇文档的知识知道这个可以使用广义瑞利商来求极大值

广义瑞利商

**背景介绍及推导见(瑞利商(Rayleigh quotient)与广义瑞利商(genralized Rayleigh quotient)
**
下面只摘抄一些:

广义瑞利商是指这样的函数𝑅(𝐴,𝐵,𝑥):

R(A,B,x) = \cfrac{X^{H}Ax}{X^{H}Bx}

其中𝑥为非零向量,而𝐴,𝐵为𝑛×𝑛的Hermitan矩阵。𝐵为正定矩阵

<math><semantics><mrow><mi>A</mi><mo>=</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup></mrow><annotation encoding="application/x-tex">A=(\mu_+-\mu_-)(\mu_+-\mu_-)^T </annotation></semantics></math>A=(μ+μ)(μ+μ)T

<math><semantics><mrow><mi>B</mi><mo>=</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub></mrow><annotation encoding="application/x-tex">B= \Sigma_+-\Sigma_- </annotation></semantics></math>B=Σ+Σ

\mathop{\arg\max}\limits_{w} J(w) = \frac{w^TAw}{w^TBw}

这个就很广义瑞利商了。

至于w的值,使用拉格朗日乘子法可以求解得到:

<math><semantics><mrow><msup><mi>B</mi><mrow><mo>−</mo><mn>1</mn></mrow></msup><mi>A</mi><mi>w</mi><mo>=</mo><mi>λ</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">B^{-1}Aw = \lambda w </annotation></semantics></math>B1Aw=λw

<math><semantics><mrow><msup><mi>B</mi><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi><mo>=</mo><mi>λ</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">B^{-1}(\mu_+-\mu_-)(\mu_+-\mu_-)^Tw = \lambda w </annotation></semantics></math>B1(μ+μ)(μ+μ)Tw=λw

由于

<math><semantics><mrow><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><msup><mo>)</mo><mi>T</mi></msup><mi>w</mi></mrow><annotation encoding="application/x-tex">(\mu_+-\mu_-)^Tw </annotation></semantics></math>(μ+μ)Tw

是行向量乘列向量,所以结果是一个标量,
那我们知道:

<math><semantics><mrow><msup><mi>B</mi><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>∝</mo><mi>λ</mi><mi>w</mi></mrow><annotation encoding="application/x-tex">B^{-1}(\mu_+-\mu_-) \propto \lambda w </annotation></semantics></math>B1(μ+μ)λw

<math><semantics><mrow><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><msup><mo>)</mo><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo><mo>∝</mo><mi>w</mi></mrow><annotation encoding="application/x-tex">(\Sigma_+-\Sigma_-)^{-1}(\mu_+-\mu_-) \propto w </annotation></semantics></math>(Σ+Σ)1(μ+μ)w

由于w我们只关注方向而不是长度,所以可以认为:

<math><semantics><mrow><msub><mi>w</mi><mrow><mi>b</mi><mi>e</mi><mi>s</mi><mi>t</mi></mrow></msub><mo>=</mo><mo>(</mo><msub><mi mathvariant="normal">Σ</mi><mo>+</mo></msub><mo>−</mo><msub><mi mathvariant="normal">Σ</mi><mo>−</mo></msub><msup><mo>)</mo><mrow><mo>−</mo><mn>1</mn></mrow></msup><mo>(</mo><msub><mi>μ</mi><mo>+</mo></msub><mo>−</mo><msub><mi>μ</mi><mo>−</mo></msub><mo>)</mo></mrow><annotation encoding="application/x-tex">w_{best} =(\Sigma_+-\Sigma_-)^{-1}(\mu_+-\mu_-) </annotation></semantics></math>wbest=(Σ+Σ)1(μ+μ)

教科书上的LDA为什么长这样?
线性判别分析LDA原理总结

评论区

PhD Candidate in Machine Learning

0

0

0