📜  理解 LSTM 网络(1)

📅  最后修改于: 2023-12-03 15:40:51.192000             🧑  作者: Mango

理解 LSTM 网络

LSTM(Long-Short Term Memory)是一种特殊的循环神经网络(Recurrent Neural Network,简称 RNN)。它被广泛用于处理序列数据,包括文本、语音、视频等。

相比于普通的 RNN,LSTM 添加了三个特殊的门控节点,可以有效地解决梯度消失和梯度爆炸的问题,从而更好地实现长期依赖关系的建模。

LSTM 的结构

LSTM 的结构非常复杂,但可以简单地分为三部分:遗忘门(forget gate)、输入门(input gate)和输出门(output gate)。

LSTM

遗忘门

假设当前时刻为 t,LSTM 网络需要记忆前 t-1 个时刻的信息,从而判断当前时刻的信息是否有用。遗忘门的作用就是实现信息的遗忘,只保留重要的信息。

具体来说,遗忘门会根据当前时刻的输入和前一个时刻的状态,输出一个数值 0 到 1 之间的实数,表示需要保留的信息量。数值越大,保留的信息就越多;数值越小,保留的信息就越少。

遗忘门的公式如下:

$$f_t = \sigma(W_f \cdot [h_{t-1},x_t] + b_f)$$

其中,$h_{t-1}$ 表示前一个时刻的状态,$x_t$ 表示当前时刻的输入,$W_f$ 和 $b_f$ 是遗忘门的权重参数,$\sigma$ 是 sigmoid 函数。

输入门

输入门的作用是根据当前时刻的输入和前一个时刻的状态,更新当前时刻的记忆 cell。

具体来说,输入门会根据当前时刻的输入和前一个时刻的状态,输出一个数值 0 到 1 之间的实数,表示需要更新 cell 的程度。数值越大,更新的程度就越大;数值越小,更新的程度就越小。

另外,输入门还会根据当前时刻的输入,生成一个候选值 $C_t$,作为更新后的 cell 的新值。

输入门的公式如下:

$$i_t = \sigma(W_i \cdot [h_{t-1},x_t] + b_i)$$

$$C_t = \text{tanh}(W_C \cdot [h_{t-1},x_t] + b_C)$$

其中,$W_i$、$W_C$ 和 $b_i$、$b_C$ 分别是输入门的权重参数和偏置参数。

输出门

输出门的作用是根据当前时刻的输入和前一个时刻的状态,输出当前时刻的预测值。

具体来说,输出门会根据当前时刻的输入和前一个时刻的状态,输出一个数值 0 到 1 之间的实数,表示需要输出的程度。数值越大,输出的程度就越大;数值越小,输出的程度就越小。

另外,输出门还会根据更新后的 cell,生成当前时刻的预测值 $h_t$。

输出门的公式如下:

$$o_t = \sigma(W_o \cdot [h_{t-1},x_t] + b_o)$$

$$h_t = o_t \cdot \text{tanh}(C_t)$$

其中,$W_o$ 和 $b_o$ 是输出门的权重参数和偏置参数。

LSTM 的实现

在 TensorFlow 中,可以使用 tf.keras.layers.LSTM 实现 LSTM 网络。例如,以下代码实现了一个单层的 LSTM 网络:

import tensorflow as tf

model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(128, input_shape=(length, dim)))
model.add(tf.keras.layers.Dense(num_classes))

其中,LSTM(128, input_shape=(length, dim)) 表示添加一个有 128 个神经元的 LSTM 层,输入数据为 (length, dim) 的张量,Dense(num_classes) 表示添加一个 num_classes 个输出的全连接层。

参考资料
  • Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
  • https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM