Skip to content

Lesson-6 循环神经网络

循环神经网络(RNN)

循环神经网络是一类用于处理序列数据的神经网络结构,擅长捕获时间序列或序列数据中的动态信息。

RNN 的流程如下:

  • 首先,对输入进行 word embedding:\(e^{(t)} = Ex^{(t)}\)
  • 然后,每一时刻的状态只和之前时刻的状态和这一时刻的输入有关:\(h^{(t)} = \sigma(W_h h^{(t-1)} + W_ee^{(t)}+b_1)\),其中\(\sigma\)为激活函数
  • 最后,每一步的输出只考虑当前的隐藏状态:\(\hat{y^{(t)}}=softmax(u^T h^{(t)}+b_2)\)

alt text

RNN 的优点:

  • 可以处理任意长度的输入
  • 模型并不会随着输入长度的变长而变大
  • 计算 step-t 的时候可以考虑前面输入的信息

缺点:

  • 每一步都需要计算且无法并行计算,循环计算很慢
  • 计算 step-t 的时候实际上很难考虑到前面很多步的输入的信息

训练 RNN 的思路如下:

每一步都有一个损失,使用预测概率分布\(\hat{y}^{(t)}\)和真实概率分布\({y}^{(t)}\)的交叉熵来定义:

\[J^{(t)}(\theta) = - \sum_{w \in V} y_w^{(t)}\log \hat{y_w}^{(t)} \]

而因为真实分布 \(y_w^{(t)}\) 是独热编码,因此上式可以化简为:\(- \log \hat{y}_{x_{t+1}}^{(t)}\)

根据每一步的损失,我们可以定义出整个训练集的损失,等于每一步损失的平均:

\[J^{(t)}(\theta) = \frac{1}{T}\sum_{i=1}^TJ^{(t)}(\theta)= -\frac{1}{T}\sum_{i=1}^T \log \hat{y}_{x_{t+1}}^{(t)}\]

alt text

我们发现:\(W_h\)参与了每个隐藏状态之间的转换,因此:\(J^{(t)}(\theta)\)\(W_h\) 的倒数为:

\[\frac{\partial J^{(t)}}{\partial W_h} = \sum_{i=1}^t \frac{\partial J^{(i)}}{\partial W_h}\]
梯度消失和梯度爆炸

(1)梯度消失:

alt text

不难发现:对于\(t_1\)来说,较远的梯度信号\(J^{(4)}(θ)\)比较近的梯度信号\(J^{(2)}(θ)\)小得多,所以模型的参数更新取决于较近的梯度信号。

但是,RNN 的场景中往往需要考虑建立长距离依赖。

解决方法:使用 ReLU、LeakyReLU、BatchNorm 等等方法

(2)梯度爆炸:

alt text

每一层的梯度假设都很大,由于链式法则,下面的层得到反向传播的梯度就会很大,这会导致:会导致梯度下降的step过大,出现bad updates。即因为步子迈的太大而找不到最小值点。

解决方法:梯度裁剪,很简单的思路:如果梯度的范数大于某个阈值,则在应用SGD更新之前将其缩小(按照缩小比例): alt text

LSTM 架构

长短期记忆网络(LSTM,Long Short-Term Memory)是一种特殊的RNN结构,旨在解决标准RNN在处理长序列时面临的梯度消失和梯度爆炸问题。LSTM通过引入门控机制(即遗忘门、输入门和输出门),能够在更长的时间序列中保持信息。

alt text

下面我们逐个介绍 LSTM 的组件:

  • cell state:表示LSTM本身的状态信息,一步一步往下传

alt text

  • 遗忘门:\(f_t = \sigma(W[h_{t-1},x_t]+b_f)\),其中\(\sigma\)是激活函数 Sigmoid,因此:\(f_t\)的取值在0和1之间

alt text

  • 输入门:首先,生成候选单元状态(可以理解为等待处理的\(c_t\)):\(\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)\);接下来,再生成输入门的激活值:\(i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)\),那么,\(i_t\)也是一个在0到1之间到值。然后,我们对记忆单元进行更新:\(c_t = f_t \cdot c_{t-1} + i_t \cdot \tilde{c}_t\),通过\(f_t \cdot c_{t-1}\)部分保留过去的记忆,通过\(i_t \cdot \tilde{c}_t\)引入新的记忆

alt text

  • 输出门:输出门控制输出的隐藏状态,这部分公式很简单:先获得激活值(可以理解为输出多少 \(c_t\)):\(o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)\),在得到当前的隐藏状态:\(h_t = o_t \cdot \tanh(c_t)\)

alt text

LSTM 的优点:

  • 门控机制和长期记忆状态使得梯度可以在多个时间步内稳定地传递
  • 通过遗忘门和记忆单元,可以保存和选择性丢弃时间序列中长期相关的信息
  • 适用于序列数据

缺点:

  • 计算复杂度高
  • 训练时间长