机器学习之Normal Equation求解

接触到正规方程组(Normal Equation)是线性回归中对模型参数的求解。另外一篇笔记【机器学习之线性回归】已经记录了线性回归的知识点,这篇笔记着重讨论Normal Equation的求解。吴恩达老师在机器学习公开课视频中提到他会经常使用他所讲述的方法进行一些推导,而且步骤相当简单,所以有必要仔细学习他讲述的方法。

线性回归模型

对于特征个数为n的模型,$\boldsymbol{x} = \left[ x_0, x_1, \cdots, x_n \right]^\text{T} \in \mathbb{R}^{(n+1) \times 1}$ 表示输入,$\boldsymbol{\theta} = \left[ \theta_0, \theta_1, \cdots, \theta_n \right]^\text{T} \in \mathbb{R}^{(n+1) \times 1}$ 表示模型参数,则模型的输出为:

显然,$h_\boldsymbol{\theta}(\boldsymbol{x})$ 是一个实数。同时,记输入 $\boldsymbol{x}$ 的真实输出为 $y$。则对于输入 $\boldsymbol{x}$,模型的平方误差为:

令样本数量为m,当 $ 1 \leq i \leq m$ 时,$\boldsymbol{x}_i = \left[ x_{i0}, x_{i1}, \cdots, x_{in} \right]^\text{T} \in \mathbb{R}^{(n+1) \times 1}$ 表示第i个样本的输入,$y_i$ 表示第i个样本的真实输出,记均方误差为:

我们希望模型参数 $\boldsymbol{\theta}$ 能使均方误差最小,最常用的求解方法就是梯度下降(Gradient Descent)正规方程组(Normal Equation)

通俗的说,Normal Equation就是求导,而导数等于零的点就是极值点(通过泰勒展开可以推导)。接下来的所有工作无非就是对均方误差函数的求导。

均方误差的矩阵形式

先将均方误差写作矩阵的形式。

令 $\boldsymbol{w} = \left[ w_1, w_2, \cdots, w_m \right]^{\text{T}} $,其中 $ w_i = h_\boldsymbol{\theta}(\boldsymbol{x}_i) - y_i $。有

又有

所以

求导

其中,$\boldsymbol{\theta}^T\boldsymbol{X}^T\boldsymbol{X}\boldsymbol{\theta}$ 、$\boldsymbol{\theta}^T\boldsymbol{X}^T \boldsymbol{y}$、$\boldsymbol{y}^T\boldsymbol{X}\boldsymbol{\theta}$ 都是标量(1×1的矩阵),所以可以分解成三个【标量对向量求导】的问题。

一般方法

分别对上述三个标量求导:

其中,$\frac{\partial}{\partial \theta_k} \sum \limits_{i=1}^{n} \sum \limits_{j=1}^{n}\theta_i w_{ij} \theta_j$ 的求导需要详细叙述一下,注意到$w_{ik} = w_{ki}$,求导过程如下:

所以式(1)等于

利用迹求导

吴恩达老师在求导过程中引入了迹的求导,这里的推导与吴恩达老师使用的公式稍微有些不同,但最终效果应该是一样的。

n×n矩阵 $\boldsymbol{A} = \left[ a_{ij} \right]$ 的迹定义为:

迹只对方阵有意义,非方阵矩阵无迹的定义。

关于迹的部分等式:

  1. $ a \in \mathbb{R}^{1 \times 1}$
  2. $ \boldsymbol{A} \in \mathbb{R}^{n \times n}$
  3. $ \boldsymbol{A} \in \mathbb{R}^{m \times n}$、$ \boldsymbol{B} \in \mathbb{R}^{n \times m}$
  4. $ \boldsymbol{A} \in \mathbb{R}^{m \times n}$、$ \boldsymbol{B} \in \mathbb{R}^{n \times r}$、$ \boldsymbol{C} \in \mathbb{R}^{r \times m}$

关于迹的求导:

  1. $ \boldsymbol{A} \in \mathbb{R}^{n \times n}$
  2. $ \boldsymbol{A} \in \mathbb{R}^{m \times n}$、$ \boldsymbol{B} \in \mathbb{R}^{n \times m}$
  3. $ \boldsymbol{A} \in \mathbb{R}^{m \times n}$、$ \boldsymbol{B} \in \mathbb{R}^{m \times n}$
  4. $ \boldsymbol{A} \in \mathbb{R}^{m \times n}$、$ \boldsymbol{B} \in \mathbb{R}^{m \times m}$
  5. $ \boldsymbol{A} \in \mathbb{R}^{m \times n}$、$ \boldsymbol{B} \in \mathbb{R}^{n \times n}$

根据式(2)可以将式(1)表示成对迹求导,然后依据式(9)、式(8)、式(7)可以方便的推出求导的结果:

令导数等于零,就可以得到Normal Equation了:

进而可以得到 $\boldsymbol{\theta}$ 的解:

参考资料

  • 吴恩达机器学习视频
  • 矩阵分析与应用 张贤达