Lesson-9 Transformer
我们发现:RNN 在计算上存在一个问题:即其每次生成都需要计算 N 步,并且这 N 步是按顺序依次计算的,不可以并行
Self-Attention
鉴于 RNN 的这个缺点,我们提出 Self-Attention:每个词都要 Attend 到所有词,并且距离都认定为1,这个过程是可以并行的。
具体来说,Self-Attention 的流程是这样的:
- 首先,我们有一系列输入:\(x_1,x_2,...,x_n\)
- 对于每个输入,我们都计算出三个向量:\(q_i = Q x_i\),\(k_i = K x_i\),\(v_i = V x_i\)
- 对于任意两个输入向量之间,我们都计算它们的 Attention 分数:\(e_ij = q_i^T k_j\),然后,对于每个\(i\),我们根据 Attention 分数来计算对应的权重:\(\alpha_i = \textbf{Softmax}(e_i)\)
- 根据得到的权重,我们得到输出:\(o_i = \sum _j \alpha_{ij}v_j\)
Self-Attention 的问题
(1)对位置不敏感:我们不难发现,Self-Attention 中并没有引入任何与位置有关的内容,任意两个位置之间都默认距离为1。
解决方法:位置编码:把词的位置表示为向量
常见的位置编码有:
(i)正弦/余弦函数的位置表示法:
优点为:
- 周期性的相对位置表示,可以外推
缺点:
- 不可以学习
(ii)绝对位置表示法:
顾名思义,每个位置的位置编码都是可以学习的,即学习出一个位置矩阵
优点:
- 每个位置编码都可以学习
缺点:
- 难以外推
(2)我们发现,Self-attention 的结果实际上就是把每个输入生成的value进行线性求和,线性有时候是不足以解决问题的。因此,我们在每个输出后面都添加一个 Feed-Forward:
前馈层包含至少一个非线性激活函数,它能够引入更多的表达能力和复杂的特征组合。
(3)训练的时候,模型是不能看见未来的词的,这是不可接受的,因为生成一个词时,模型应该只能基于当前和过去的词进行预测,而不能看到将来。解决方法:Masking:在训练语言生成任务(如机器翻译、文本生成)时,通常将一个词后面的所有词的attention score遮盖,确保模型在生成当前词时只能看到当前和之前的词:
Transformer
总体架构如下图:
从 Encoder 开始说起:首先,输入经过位置编码后进入一个 Block 中,首先,先进行 Multihead Attention,接着,进行 Add 和 Norm,得到的输出再被送入一个前馈层中,再进行一次 Add 和 Norm,这就得到了一个 Block 的输出,会存在多个 Block。
接着就是 Decoder,步骤与 Encoder 基本一致,但也存在区别:(1)第一次 Multihead Attention 后 Add 和 Norm 后会进行与 Encoder 的信息交流:Cross-Attention;(2)Decoder 的输入为自己先前的输出。
得到Decoder的输出后,再送入一个线性层中,最后送入 Softmax 中,得到最终的输出。
下面,我们介绍一下核心组件:
(1)Transformer 中采用的是 Scaled Dot-product Attention:这是因为:随着长度增加,dot-product会变大(分子),所以除以\(d\)来消减这个值(和\(d\)解耦):
(2)多头注意力:多头自注意力通过并行计算多个自注意力头,能够从不同的子空间学习多种语义和上下文信息,从而增强模型的表达能力。
具体公式如下:
其中\(W_o\)是输出的线性变换矩阵,将拼接后的结果映射回原始维度。
(3)Residual connections:残差连接的作用是让信息在网络中更容易流动,避免梯度消失问题。这是通过将输入直接添加到输出上来实现的:
它的好处:
- 缓解梯度消失问题:残差连接使得梯度能够直接从网络的较深层传播到较浅层,有助于深层网络的训练
- 促进信息流动:信息可以不经过非线性变换直接传递,从而使得模型能够学习到更有效的特征
- 改善训练稳定性:残差连接使得网络更容易训练,尤其是当模型较深时,能够加速收敛并避免过拟合
(4)Layer Norm:规范化(Normalization)层通常是在残差连接之后进行的,以使得每层的输出保持一致的尺度,避免梯度爆炸或消失。Transformer 中使用的是Layer Normalization,其公式为:
其中\(\mu\)和\(\sigma\)是输入的均值和方差,\(\epsilon\)是一个防止除了0的很小的数,\(\gamma\)和\(\beta\)都是可以学习的参数
本页面最近更新:,更新历史
发现错误?想一起完善? 在 GitHub 上编辑此页!
本页面贡献者:OI-wiki
本页面的全部内容在 协议之条款下提供,附加条款亦可能应用