第16篇:长短期记忆网络(LSTM)——解决RNN“遗忘症”的良方(原理解析)

张开发
2026/4/17 12:14:17 15 分钟阅读

分享文章

第16篇:长短期记忆网络(LSTM)——解决RNN“遗忘症”的良方(原理解析)
文章目录现象引入RNN的“记忆短路”问题提出问题如何让网络拥有“可控记忆”原理剖析LSTM的三道门与一条记忆线第一道门遗忘门Forget Gate第二道门输入门Input Gate与候选值更新细胞状态第三道门输出门Output Gate源码印证透过PyTorch看LSTM实现实际影响为什么LSTM成为里程碑现象引入RNN的“记忆短路”问题几年前我接手一个文本情感分析项目需要模型能理解句子中长距离的依赖关系比如“虽然这部电影的特效非常震撼场景也宏大演员阵容堪称豪华但是由于剧情逻辑的混乱和台词的苍白整体上让我感到非常失望”。用经典的RNN循环神经网络跑了几轮效果总是不理想。模型似乎只记住了最后“感到非常失望”却“忘记”了前面那一大串的“虽然…”导致判断时常出错。这就是RNN著名的“长期依赖”问题也叫梯度消失/爆炸。简单说RNN的记忆像金鱼信息在时间步间传递时梯度会指数级衰减或增长导致它学不会长序列中远距离的关联。当时我就想必须得用LSTM了。提出问题如何让网络拥有“可控记忆”面对RNN的“遗忘症”我们核心要解决两个问题如何长期保存重要信息比如上面例子中“虽然”这个转折词所引导的语义需要穿越很长距离去影响最后的结论。如何选择性记忆与遗忘不是所有信息都值得一直记住。比如“特效震撼”这个正面信息在遇到“但是”后其重要性就应该被降低。LSTMLong Short-Term Memory长短期记忆网络的提出正是为了赋予网络这种“可控的记忆能力”。它不是一个黑盒子其设计思想非常精妙核心在于用“门控”机制来管理一个叫做“细胞状态”的记忆主线。原理剖析LSTM的三道门与一条记忆线你可以把LSTM单元想象成一个信息加工车间其中有一条贯穿始终的传送带叫做细胞状态Cell State记为 C_t。这条传送带是LSTM实现长期记忆的关键它只在少量线性交互下贯穿时间信息在上面流动很容易保持不变。车间的所有操作都是围绕如何向这条传送带上“添加”或“移除”信息而展开的。这些操作由三个结构精巧的“门”来控制每个门都是一个Sigmoid神经网络层和一个点乘操作的组合。Sigmoid层输出0到1之间的值描述“让多少信息通过”0代表“全不让过”1代表“全放行”。第一道门遗忘门Forget Gate作用决定从细胞状态中丢弃哪些信息。这是LSTM的第一步。它查看当前输入x_t和上一个隐藏状态h_{t-1}并为细胞状态C_{t-1}中的每个元素输出一个0到1之间的数。f_t σ(W_f · [h_{t-1}, x_t] b_f)这个f_t向量将直接与上一时刻的细胞状态C_{t-1}相乘。如果f_t的某个位置是0就意味着“完全忘记”旧状态中对应的信息如果是1则意味着“完全保留”。我的理解这是“主动遗忘”机制。比如在读到“但是”时遗忘门就应该学习去降低前面那些正面描述信息在细胞状态中的权重。第二道门输入门Input Gate与候选值作用决定将哪些新信息存入细胞状态。这一步包含两部分输入门i_t一个Sigmoid层决定我们将更新哪些值。i_t σ(W_i · [h_{t-1}, x_t] b_i)候选记忆细胞~C_t一个tanh层创建一个新的候选值向量这些值可能会被加入到细胞状态中。~C_t tanh(W_C · [h_{t-1}, x_t] b_C)接下来我们将这两部分结合来对细胞状态进行更新。更新细胞状态现在我们可以把旧的细胞状态C_{t-1}更新为新的C_t了。C_t f_t * C_{t-1} i_t * ~C_t这个公式是LSTM的核心它分两步f_t * C_{t-1}遗忘掉我们决定要遗忘的部分。i_t * ~C_t添加我们决定要添加的新候选值由输入门筛选过的。通过这种“先忘后加”的线性操作细胞状态C_t实现了信息的可控流转和长期保存。梯度在这里可以稳定地流动有效缓解了消失问题。第三道门输出门Output Gate作用基于细胞状态决定输出什么。首先运行一个Sigmoid层输出门来确定细胞状态的哪些部分将被输出。o_t σ(W_o · [h_{t-1}, x_t] b_o)然后我们将细胞状态通过tanh函数将值压到-1和1之间并将其与输出门的输出相乘得到最终的隐藏状态h_t这个h_t也会被传递到下一个时间步并作为当前时刻的输出。h_t o_t * tanh(C_t)注意h_t和C_t是不同的。C_t是内部记忆主线h_t是对外暴露的、经过过滤的“摘要信息”。源码印证透过PyTorch看LSTM实现理论说得再漂亮不如看一行代码。我们以PyTorch为例看看LSTM单元的核心计算是如何实现的。这能帮你彻底理解上面的公式。importtorchimporttorch.nnasnn# 定义一个单层LSTM单元输入维度10隐藏状态维度20lstm_cellnn.LSTMCell(input_size10,hidden_size20)# 初始化隐藏状态h0和细胞状态c0hxtorch.randn(3,20)# (batch_size, hidden_size)cxtorch.randn(3,20)# (batch_size, hidden_size)# 当前时间步的输入inputtorch.randn(3,10)# (batch_size, input_size)# 前向传播一次对应我们上面讲的所有公式hx_next,cx_nextlstm_cell(input,(hx,cx))# 我们自己手动实现一遍LSTMCell的核心计算加深理解defmanual_lstm_cell(x,hx,cx,weights): x: 当前输入 hx: 上一时刻隐藏状态 cx: 上一时刻细胞状态 weights: 包含所有W和b的字典为简化这里省略拼接和拆解细节 实际PyTorch源码中是一次性计算所有门再拆分的效率更高。 # 1. 将输入和上一隐藏状态拼接combinedtorch.cat((x,hx),dim1)# 2. 一次性计算所有门和候选值实际源码做法# 这对应公式中的 W * [h, x] b输出维度是 4 * hidden_sizegatestorch.mm(combined,weights[weight_ih].T)weights[bias_ih]\ torch.mm(hx,weights[weight_hh].T)weights[bias_hh]# 3. 拆分出输入门(i)、遗忘门(f)、细胞候选值(g)、输出门(o)ingate,forgetgate,cellgate,outgategates.chunk(4,1)# 4. 应用激活函数ingatetorch.sigmoid(ingate)forgetgatetorch.sigmoid(forgetgate)cellgatetorch.tanh(cellgate)outgatetorch.sigmoid(outgate)# 5. 更新细胞状态核心公式cy_nextforgetgate*cxingate*cellgate# 6. 计算输出/下一隐藏状态hy_nextoutgate*torch.tanh(cy_next)returnhy_next,cy_next看manual_lstm_cell函数中的第5步cy_next forgetgate * cx ingate * cellgate正是我们原理部分讲的核心更新公式。PyTorch的官方实现torch.nn._functions.rnn.LSTMCell在底层也是严格按照这个数学定义来的只是用了更高效的矩阵一次运算。实际影响为什么LSTM成为里程碑LSTM的提出1997年是RNN发展史上的一个里程碑其影响深远解决了工程难题在实际应用中如机器翻译、语音识别、时间序列预测等需要处理长序列的任务上LSTM的表现远超传统RNN使其变得真正可用。启发了更多结构LSTM的成功证明了“门控”机制的有效性直接催生了后来更简洁的GRUGated Recurrent Unit以及更复杂的双向LSTM、深度LSTM等变体。奠定了序列建模基础在Transformer崛起之前LSTM及其变体几乎是所有序列建模任务的默认选择是自然语言处理从统计方法走向神经网络方法的关键支柱之一。我的踩坑提示虽然LSTM强大但别把它当银弹。它的计算量比RNN大参数也多。对于不是特别长的序列或者当数据量不足时简单的RNN或GRU可能是更高效的选择。而且自从Transformer出现后在很多任务上基于自注意力的模型在长距离依赖捕捉和能力上已经超越了LSTM。但理解LSTM依然是理解序列建模思想不可或缺的一课。总结一下LSTM通过引入“细胞状态”和“遗忘、输入、输出”三道门精巧地实现了对信息的长期记忆和选择性控制一举攻克了RNN的梯度消失难题。它的设计思想是神经网络结构创新中的一个经典范例。如有问题欢迎评论区交流持续更新中…

更多文章