跳转至

0x01 注意力机制

几种注意力变种

alt text

Scaled Dot-Product Attention 图示

alt text

从下到上的计算顺序如下

  • 变换:输入为 hidden states \(X\),先经过线性变换,将 X 转换为 \(Q、K、V\) 三个矩阵
  • 内积:计算 \(Q\)\(K\) 之间的内积分数,也就是\(QK^\top\)
  • 缩放:将分数除以向量维度的平方根 \(\sqrt{d}\),得到注意力分数
  • 归一化:使用 softmax 函数对注意力权重归一化,\(\text{softmax()} = \frac{\exp(a)}{\sum \exp(a)}\),得到注意力权重
  • 结果:将归一化分数与 value 相乘

有个经常被考察的问题:为什么注意力计算公式中的 \(\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^\top}{\sqrt{d}})V\) 中要除以 \(\sqrt{d}\)

原始论文的解释:当维度\(d_k\) 比较大时,\(q\)\(k\)的点积会出现比较大的的数值,而大数值经过 softmax 函数,会使分数的分布变得陡峭(大分数趋近于成 1,小分数趋近于 0),从而将 softmax 的函数梯度区平缓,变得很小,让模型难以收敛,加大训练和学习的难度。因此按照 \(\frac{1}{\sqrt{d}}\) 做缩放,使得点积和维度解耦,消除了维度增大导致内积增大的效应,从而维持梯度的稳定性。

总结一下论文解释的的关键要点

  • 如果维度 \(d_k\) 变大,那么 \(q\cdot k^\top\)方差会变大,进而导致不同 \(q、k\) 内积之间的差变大
  • 差值变大,会使得 softmax 函数退化成 argmax,大分数趋近于成 1,小分数趋近于 0
  • 如果只有少数几个值趋近 1,大部分参数反向传播后梯度均趋近 0,出现梯度消失

方差变大

一般认为输入神经网络的向量,服从均值为 0,方差为 1 的正态分布,分析一下 \(q \cdot k^\top\)的方差:

\[ var[q_i \cdot k_i^\top] = var[\sum^{d_k}_{i=1} q_i \times k_i] = \sum^{d_k}_{i=1} var[q_i \times k_i] = \sum^{d_k}_{i=1} var[q_i] \times var[k_i] = \sum^{d_k}_{i=1}1 = d_k \]

由此可见,方差随着向量的维度\(d_k\)发生变化。

softmax 变化

下面再来分析一下 softmax 的在高方差情况在的退化。 softmax 函数中,每个分量的计算如下:

\[ \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} \]

将每个元素 \(x_i\) 表示为其中 最大元素 \(x_\text{max} = \text{max}(x)\) 和一个差值 \(\Delta_i\),也就是\(x_i = x_\text{max} - \Delta_i\),从而将 softmax 函数重写为:

\[ \text{softmax} (x_i) = \frac{\exp(x_{\text{max}} - \Delta_i)}{\sum_j \exp(x_{\text{max} }-\Delta_j)} = \frac{\exp( - \Delta_i)}{\sum_j \exp(-\Delta_j)} \]

而,\(x_{\text{max}}\) 趋近于 1,\(\Delta_i\)趋近于0,导致最后的 \(\text{Softmax}(x_i)\) 也接近 0.

因此,在应用 Softmax 之前,我们需要找到一种解决方案,帮助减少这些数字的方差。现在问题已经简化为降低包含初始注意力分数的乘积矩阵的方差,Vanilla Transformer 提出的方案就是缩放矩阵以获得与之前相同的方差。

注意力的优化策略

序列角度优化

原始自注意力机制中,每个token都要和所有token作注意力计算。但是Transformer中的注意力分布其实是不均匀的:

  • 起始层的注意力分布大致均匀
  • 中间层的注意力模式变得更加复杂,大部分概率集中在初始 token(注意力汇聚)和最近的/最后 tokem(近期偏见)

因此,部分对注意力改进在稀疏计算,每个 token 只与一部分 token 做计算,关键在于如何选择注意力的计算方式,一些典型的类别如下:

  • Atrous self-attention(空洞自注意力):每个 token 等间隔的和其他 token 做注意力计算
  • Local self-attention(局部注意力):类似于 n-gram,每个 token 与附近的几个 token 做注意力计算
  • Sparse attention(稀疏注意力):结合了空洞注意力和局部注意力机制,根据规则(或者动态决定),每个 token 即有局部感受野,也有全局感受野

多头角度优化

具体包括MHA、MQA、GQA、MLA

软硬件结合角度优化

硬件层面上,现在已在使用的 HBM(高速带宽内存)提高读取速度。

软件层面上,Flash Attention,Paged Attention 等