为什么需要注意力机制

Transformer出现之前, 对于NLP任务的主力模型是:

RNN

LSTM/GRU

这些模型都承担着一个艰难任务:

在一个长序列中,理解词与词之间的关系,尤其是远距离的关系。

但它们有一个共同的致命限制:

  • 不能让每个词看到全局
  • 信息被压得太狠,丢失严重
  • 训练难以并行,效率低下

注意力机制正是为了解决这些根本性问题而诞生的。

为什么RNN/LSTM处理不好全局依赖

单向传递信息

RNN/LSTM 的信息流是链式的:

1
x1 → x2 → x3 → x4 → ...

这意味着:

  • x5 想知道 x1 的信息,要经过 x2 → x3 → x4 的层层传递
  • 一旦中间丢失,就再也补不回来
  • 越远的关系,越容易消失

所以像这种句子对它们非常困难:

“我昨天看了一部电影,电影主角是….剧情非常的跌宕起伏,….它非常好看。”

“它”“电影” 的距离太远,RNN 容易“忘”。

梯度消失/梯度爆炸

链式结构导致反向传播也需要一层层回传。

传统 RNN 的隐藏状态更新公式:

反向传播时梯度会不断乘以权重矩阵 WWW 和 tanh 的导数

不难发现, 梯度的计算呈连乘形式

当序列很长时:

  • 梯度不断变小 → 消失
  • 或不断变大 → 爆炸

这让模型学不到长期关系。

当然,LSTM/GRU 在梯度消失/爆炸问题已经做了优化, 通过门控机制控制信息流, 使一部分状态以近似线性方式在序列中传播,从而避免梯度在时间维度上的指数级衰减或爆炸,解决传统 RNN 的长期依赖问题

串行计算

RNN系列模型的特点:

当前时间步必须等上一时间步走完才能开始

因此, 训练速度很慢

而 Attention 可以并行计算, 一次矩阵乘法全部计算出来, 并行计算, 充分利用GPU

为什么不用CNN

CNN 虽然可以并行,但存在一个无法突破的问题:

CNN 的感受野是局部的。

想让 CNN 跨越长距离依赖,需要不断叠加卷积层扩展感受野。

  • 5层感受野可能只有几十个词
  • 要覆盖整个句子需要几十层甚至更多

训练难度非常大。

隐藏状态是对信息的强压缩

RNN 把之前所有词的信息压进一个固定维度的张量:

1
h_t = f(x_t, h_{t-1})

无论你前面说了 100 个词,还是 1000 个词,最终信息都必须塞进一个固定长度的 h。

上图是seq2seq架构处理翻译任务, seq2seq架构包括三部分, encoder(编码器), decoder(解码器), 中间语义张量c

对于图中的案例, 编码器首先处理中文输入 “欢迎 来 北京”, 通过GRU模型获得每个时间步的输出张量, 最后将它们拼接成一个中间语义张量c, 接着解码器将使用这个中间语义张量c以及每一个时间步的隐层张量, 逐个生成对应的翻译语言

decoder只能依赖固定长度的中间语义张量c

也就是说, 无论句子多长, 多复杂, 都要压缩成一个固定维度张量

结果就是:

  • 信息丢失严重
  • 长序列问题效果差

注意力机制如何解决这些痛点

每个词直接看到所有词

在注意力里,一个词无需通过第 2、3、4、5 个词才能知道第 1 个词的信息。

它可以直接算:

1
相关性 = Query(当前词) · Key(所有词)

这是 完全平等的全局视野

不存在梯度消失问题

因为注意力不依赖链式结构,反向传播只会经过几层矩阵运算,梯度非常稳定。

完全可并行化(压倒 RNN 的关键优势)

RNN 每一步都依赖上一部 ⇒ 只能串行
Attention 所有词之间的关系都可同时计算 ⇒ 一次矩阵乘法

速度差几十倍甚至上百倍。

这是 Transformer 能训练巨型模型的根本原因。

注意力机制如何计算

这里简要概括, 详细过程推荐知乎猛猿的一篇文章, 这里给出链接

Transformer 使用的是缩放点积注意力(Scaled Dot-Prodcut Attention)

对于每个token, 产生三个矩阵 Query, Key, Value(下文简称 Q, K, V)

  • Q(Query):我想找什么?
  • K(Key):我有哪些信息?
  • V(Value):对应的信息内容

(1) Q 与 K 做点积:计算“匹配程度”

1
score = Q • Kᵀ

相似度越高 → 代表 Query 认为这个 Key 更相关。

(2) 缩放(除以 √d)防止数值过大

1
score = score / √d

(3) Softmax → 转成概率分布(注意力权重)

1
attention_weights = softmax(score)

(4) 用权重加权求和 Value,得到最终输出

1
output = attention_weights × V

自注意力机制

对于一般情况下, Query来自decoder, Key, Value来自encoder

自注意力则是一种特殊情况, Query, Key, Value均来自同一个序列, 即