Deriving Gradients by Hand: Layer Normalization

Yesterday I derived the backward-propagation gradient of cross entropy; today let's derive layer normalization (LayerNorm), a widely used normalization method.

Forward pass

Suppose the $m$-dimensional vector to be normalized is $x$, with mean $\mu(x)$ and standard deviation $\sigma(x)$, and the LayerNorm parameters are $w$ and $b$. The layer-normalized output is then:
$$
y = w \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + b
$$

The small constant $\epsilon$ prevents division by zero when the standard deviation vanishes. For simplicity, we drop this term, and the formula becomes:
$$
y = w \odot \frac{x - \mu}{\sigma} + b
$$
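As a quick sanity check, the forward pass can be sketched in NumPy (the function name and example values below are my own, not from the post; `eps` defaults to 0 to match the simplified formula):

```python
import numpy as np

def layernorm_forward(x, w, b, eps=0.0):
    """LayerNorm forward pass for a 1-D input x: y = w * x_hat + b."""
    mu = x.mean()                          # mean mu(x)
    sigma = x.std()                        # biased standard deviation sigma(x)
    x_hat = (x - mu) / np.sqrt(sigma**2 + eps)
    return w * x_hat + b

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.ones(4)
b = np.zeros(4)
y = layernorm_forward(x, w, b)
```

With $w=\mathbf{1}$ and $b=\mathbf{0}$, the output is just $\hat x$, so it should have zero mean and unit standard deviation.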

Backward pass

Suppose the gradient of the loss $\mathcal{L}$ with respect to the output $y$ is $\frac{\partial{\mathcal{L}}}{\partial{y}}$. We need three gradients: $\frac{\partial{\mathcal{L}}}{\partial{w}}$, $\frac{\partial{\mathcal{L}}}{\partial{b}}$, and $\frac{\partial{\mathcal{L}}}{\partial{x}}$.

Let $\hat x = \frac{x - \mu}{\sigma}$, so the formula becomes:
$$
y = w \odot \hat x + b
$$

The gradients of the two parameters are straightforward:
$$
\begin{aligned}
\frac{\partial{\mathcal{L}}}{\partial{w_i}} &= \frac{\partial{\mathcal{L}}}{\partial{y_i}} \cdot \frac{\partial{y_i}}{\partial{w_i}} = \frac{\partial{\mathcal{L}}}{\partial{y_i}} \cdot \hat x_i \\
\frac{\partial{\mathcal{L}}}{\partial{b_i}} &= \frac{\partial{\mathcal{L}}}{\partial{y_i}} \cdot \frac{\partial{y_i}}{\partial{b_i}} = \frac{\partial{\mathcal{L}}}{\partial{y_i}}
\end{aligned}
$$
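These two formulas translate directly into code. A minimal sketch (the input and upstream gradient are arbitrary example values I made up):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
x_hat = (x - x.mean()) / x.std()       # normalized input from the forward pass
dy = np.array([0.1, -0.2, 0.3, 0.4])   # example upstream gradient dL/dy

dw = dy * x_hat    # dL/dw_i = dL/dy_i * x_hat_i
db = dy.copy()     # dL/db_i = dL/dy_i
```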

The gradient with respect to the input $x$ is:
$$
\frac{\partial{\mathcal{L}}}{\partial{x_i}} = \frac{1}{\sigma} \cdot \left[\frac{\partial{\mathcal{L}}}{\partial{y_i}} \cdot w_i - \frac{1}{m}\cdot \left( \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j} + \hat x_i \cdot \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j \cdot \hat x_j}\right)\right]
$$
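This formula can be implemented in a few lines of NumPy. A sketch under the post's simplifying assumption that $\epsilon$ is dropped (function name and example values are mine):

```python
import numpy as np

def layernorm_backward_x(dy, x, w):
    """dL/dx for the simplified LayerNorm (eps dropped), per the formula above."""
    m = x.shape[0]
    sigma = x.std()                     # biased std, matching the derivation
    x_hat = (x - x.mean()) / sigma
    g = dy * w                          # g_j = dL/dy_j * w_j
    return (g - (g.sum() + x_hat * (g * x_hat).sum()) / m) / sigma

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([0.5, 1.0, 1.5, 2.0])
dy = np.array([0.1, -0.2, 0.3, 0.4])   # example upstream gradient
dx = layernorm_backward_x(dy, x, w)
```

The two sums are shared across all $i$, which is why the backward pass is $O(m)$ rather than the $O(m^2)$ a naive Jacobian product would cost.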

Derivation

The gradient with respect to the input $x$ can be written as:
$$
\begin{aligned}
\frac{\partial{\mathcal{L}}}{\partial{x_i}} &= \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot \frac{\partial{y_j}}{\partial{\hat x_j}}} \cdot \frac{\partial{\hat x_j}}{\partial{x_i}} \\
&= \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j \cdot \frac{\partial{\hat x_j}}{\partial{x_i}}}
\end{aligned}
$$

So only the last factor remains to be computed:
$$
\begin{aligned}
\frac{\partial{\hat x_j}}{\partial{x_i}} &= \frac{\partial}{\partial{x_i}}\left(\frac{x_j - \mu}{\sigma}\right) \\
&= \frac{\partial{(x_j - \mu)}}{\partial x_i} \cdot \sigma^{-1} + (x_j - \mu) \cdot \frac{\partial\sigma}{\partial x_i} \cdot (-\sigma^{-2}) \\
&= \left(\delta_{ij} - \frac{\partial{\mu}}{\partial x_i}\right) \cdot \sigma^{-1} - \sigma^{-2} \cdot (x_j - \mu) \cdot \frac{\partial\sigma}{\partial x_i}
\end{aligned}
$$

Here $\delta_{ij}$ equals 1 only when $i=j$ and 0 otherwise. All that remains is the gradients of the mean and standard deviation with respect to $x_i$. We state the results directly; the proofs are simple and deferred to the end of the post:
$$
\begin{aligned}
\frac{\partial \mu}{\partial x_i} &= \frac{1}{m} \\
\frac{\partial \sigma}{\partial x_i} &= \frac{1}{m} \cdot \sigma^{-1} \cdot (x_i - \mu)
\end{aligned}
$$
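Both results are easy to verify numerically with finite differences (the example vector is arbitrary; note that the biased standard deviation, NumPy's default `ddof=0`, matches the derivation):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])     # arbitrary example input
m = x.shape[0]
mu, sigma = x.mean(), x.std()          # biased std, as in the derivation

dmu = np.full(m, 1.0 / m)              # d mu / d x_i = 1/m
dsigma = (x - mu) / (m * sigma)        # d sigma / d x_i = (x_i - mu) / (m * sigma)
```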

Substituting these in, we get:
$$
\begin{aligned}
\frac{\partial{\hat x_j}}{\partial{x_i}} &= \left(\delta_{ij} - \frac{\partial{\mu}}{\partial x_i}\right) \cdot \sigma^{-1} - \sigma^{-2} \cdot (x_j - \mu) \cdot \frac{\partial\sigma}{\partial x_i} \\
&= \left(\delta_{ij} - \frac{1}{m}\right) \cdot \sigma^{-1} - \sigma^{-2} \cdot (x_j - \mu) \cdot \frac{1}{m} \cdot \sigma^{-1} \cdot (x_i - \mu) \\
&= \sigma^{-1} \cdot \delta_{ij} - \frac{1}{m} \cdot \sigma^{-1} - \frac{1}{m} \cdot \sigma^{-3} \cdot (x_i - \mu) \cdot (x_j - \mu) \\
&= \sigma^{-1} \cdot \delta_{ij} - \frac{1}{m} \cdot \sigma^{-1} - \frac{1}{m} \cdot \sigma^{-1} \cdot {\hat x_i} \cdot {\hat x_j}
\end{aligned}
$$

Finally, substituting this back into the gradient $\frac{\partial{\mathcal{L}}}{\partial{x_i}}$ gives:
$$
\begin{aligned}
\frac{\partial{\mathcal{L}}}{\partial{x_i}} &= \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j \cdot \frac{\partial{\hat x_j}}{\partial{x_i}}} \\
&= \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j \cdot \left(\sigma^{-1} \cdot \delta_{ij} - \frac{1}{m} \cdot \sigma^{-1} - \frac{1}{m} \cdot \sigma^{-1} \cdot {\hat x_i} \cdot {\hat x_j}\right)} \\
&= \frac{1}{\sigma} \cdot \frac{\partial{\mathcal{L}}}{\partial{y_i}} \cdot w_i - \frac{1}{m \sigma} \cdot \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j} - \frac{\hat x_i}{m \sigma} \cdot \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j \cdot \hat x_j} \\
&= \frac{1}{\sigma} \cdot \left[\frac{\partial{\mathcal{L}}}{\partial{y_i}} \cdot w_i - \frac{1}{m}\cdot \left( \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j} + \hat x_i \cdot \sum_j {\frac{\partial{\mathcal{L}}}{\partial{y_j}} \cdot w_j \cdot \hat x_j}\right)\right]
\end{aligned}
$$

Gradients of the mean and standard deviation

The gradient of the mean is:
$$
\begin{aligned}
\frac{\partial \mu}{\partial x_i} &= \frac{\partial}{\partial x_i} \left(\frac{1}{m} \cdot \sum_j{x_j}\right) \\
&= \frac{1}{m}
\end{aligned}
$$

The standard deviation can be written as $\sigma = \left[\mu(x^2) - \mu^2(x)\right]^{\frac{1}{2}}$, so its gradient is:
$$
\begin{aligned}
\frac{\partial \sigma}{\partial x_i} &= \frac{\partial}{\partial x_i} \left[\mu(x^2) - \mu^2(x)\right]^{\frac{1}{2}} \\
&= \frac{1}{2} \cdot \left[\mu(x^2) - \mu^2(x)\right]^{-\frac{1}{2}} \cdot \left(\frac{2}{m} \cdot x_i - \frac{2}{m} \cdot \mu \right) \\
&= \frac{1}{m} \cdot \sigma^{-1} \cdot (x_i - \mu)
\end{aligned}
$$
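The identity $\sigma = \left[\mu(x^2) - \mu^2(x)\right]^{\frac{1}{2}}$ that this derivation relies on is easy to confirm on an arbitrary example vector:

```python
import numpy as np

# Check that sqrt(mean(x^2) - mean(x)^2) equals the biased standard deviation.
x = np.array([1.0, 2.0, 3.0, 4.0])
sigma_alt = np.sqrt((x ** 2).mean() - x.mean() ** 2)
```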


Reprint rules

"Deriving Gradients by Hand: Layer Normalization" by 韦阳 is licensed under the Creative Commons Attribution 4.0 International License.