动手学深度学习——GRU代码

张开发
2026/4/14 18:13:10 15 分钟阅读

分享文章

动手学深度学习——GRU代码
1. 前言上一篇我们已经从原理上认识了GRU门控循环单元它是对基础 RNN 的改进它引入了门控机制它通过更新门和重置门来控制信息流它更擅长处理长期依赖问题但是只理解公式还不够。和前面 RNN 一样真正把 GRU 学扎实最好的方式还是把公式一步一步写成代码。这一节的任务就是把 GRU 真正落到实现层面。你会看到GRU 比 RNN 多了哪些参数更新门、重置门在代码里怎么写候选隐藏状态如何计算最终隐藏状态如何更新简洁实现和从零实现分别怎么对应这一篇本质上就是把“门控记忆”变成可运行的程序。2. GRU 代码实现要解决什么如果把这一节拆开看核心其实就 4 件事2.1 初始化更多参数相比基础 RNNGRU 不再只有一组隐藏状态更新参数而是需要分别为更新门重置门候选隐藏状态各自准备参数。2.2 写新的状态更新公式也就是把上一篇的四条核心公式真正变成代码。2.3 保持语言模型训练接口一致GRU 虽然内部更复杂但对外仍然要能接输入序列初始状态输出预测最终状态2.4 对照简洁实现看 PyTorch 里的nn.GRU到底帮我们封装了哪些内容。3. 先回顾 GRU 的核心公式写代码前先把最关键的四条公式再捋清楚。更新门Z_t σ(X_t W_xz H_{t-1} W_hz b_z)重置门R_t σ(X_t W_xr H_{t-1} W_hr b_r)候选隐藏状态H_t_tilde tanh(X_t W_xh (R_t ⊙ H_{t-1}) W_hh b_h)最终隐藏状态H_t Z_t ⊙ H_{t-1} (1 - Z_t) ⊙ H_t_tilde其中σ是 sigmoid⊙是按元素乘法你可以看到GRU 和基础 RNN 最大的不同就在于更新隐藏状态之前先算门。4. GRU 从零实现先初始化参数基础 RNN 只需要一套隐藏更新参数。而 GRU 至少要准备三套。常见写法如下def get_params(vocab_size, num_hiddens, device): num_inputs num_outputs vocab_size def normal(shape): return torch.randn(sizeshape, devicedevice) * 0.01 def three(): return (normal((num_inputs, num_hiddens)), normal((num_hiddens, num_hiddens)), torch.zeros(num_hiddens, devicedevice)) W_xz, W_hz, b_z three() # 更新门 W_xr, W_hr, b_r three() # 重置门 W_xh, W_hh, b_h three() # 候选隐藏状态 W_hq normal((num_hiddens, num_outputs)) b_q torch.zeros(num_outputs, devicedevice) params [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q] for param in params: param.requires_grad_(True) return params这段代码就是 GRU 从零实现的起点。5. 为什么这里有三组“输入-隐藏-偏置”参数因为 GRU 要分别计算三类东西第一组更新门参数W_xz, W_hz, b_z它们负责控制“旧状态保留多少”。第二组重置门参数W_xr, W_hr, b_r它们负责控制“旧状态在候选状态里参与多少”。第三组候选隐藏状态参数W_xh, W_hh, b_h它们负责生成新的候选状态。所以GRU 比 RNN 参数更多不是因为它“乱加复杂度”而是因为它确实要分别处理三种不同功能。6. 隐藏状态初始化和 RNN 一样吗基本一样。因为 GRU 最终对外仍然只有一个隐藏状态H_t不像 LSTM 还会多一个单独记忆单元。所以状态初始化通常仍然写成def init_gru_state(batch_size, num_hiddens, device): return (torch.zeros((batch_size, num_hiddens), devicedevice), )也就是说每个样本一份隐藏状态初始时全零返回成元组形式方便接口统一7. GRU 的前向传播是这一节最核心的代码常见从零实现写法如下def gru(inputs, state, params): W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q params H, state outputs [] for X in inputs: Z torch.sigmoid(torch.mm(X, W_xz) torch.mm(H, W_hz) b_z) R torch.sigmoid(torch.mm(X, W_xr) torch.mm(H, W_hr) b_r) H_tilde torch.tanh(torch.mm(X, W_xh) torch.mm(R * H, W_hh) b_h) H Z * H (1 - Z) * H_tilde Y torch.mm(H, W_hq) b_q outputs.append(Y) return torch.cat(outputs, dim0), (H,)如果你把这段代码真正看明白GRU 就已经学通了大半。8. 更新门这两行代码怎么理解先看Z torch.sigmoid(torch.mm(X, W_xz) torch.mm(H, W_hz) b_z)它对应的就是更新门公式Z_t σ(X_t W_xz H_{t-1} W_hz b_z)含义是当前输入X和上一隐藏状态H共同决定当前时间步“该保留多少旧状态”Z的形状通常是(batch_size, num_hiddens)也就是说对每个样本、每个隐藏单元都有一个门控值。这意味着GRU 不是粗暴地对整个状态“一刀切”而是按隐藏单元逐个控制。9. 重置门代码怎么理解再看R torch.sigmoid(torch.mm(X, W_xr) torch.mm(H, W_hr) b_r)它对应R_t σ(X_t W_xr H_{t-1} W_hr b_r)它决定的是在计算候选隐藏状态时旧隐藏状态该参与多少。你可以把它理解成一种“历史清洗器”R大说明历史信息还很重要R小说明历史信息该弱化一些10. 候选隐藏状态代码怎么理解这是 GRU 最关键的一步之一H_tilde torch.tanh(torch.mm(X, W_xh) torch.mm(R * H, W_hh) b_h)它对应H_t_tilde tanh(X_t W_xh (R_t ⊙ H_{t-1}) W_hh b_h)要点在于不是直接用H而是先做R * H这表示旧隐藏状态先被重置门筛一遍再参与候选状态生成。这就是 GRU 相比基础 RNN 非常关键的精细化控制。11. 最终隐藏状态更新代码怎么理解看这一句H Z * H (1 - Z) * H_tilde这就是H_t Z_t ⊙ H_{t-1} (1 - Z_t) ⊙ H_t_tilde直观上它就是在做旧状态和新候选状态的加权平均如果Z大则更偏向旧状态Z小则更偏向新候选状态所以最终隐藏状态不是“全用旧的”也不是“全用新的”而是模型自己学出来的动态折中。这正是 GRU 最大的魅力所在。12. 为什么说 GRU 的记忆更可控从这几行代码里你就能直接看出来基础 RNN 的状态更新基本是一条路输入来历史来一起过tanh更新完成而 GRU 明显多了控制环节先决定历史该参与多少再生成候选状态再决定最终保留多少旧状态所以 GRU 本质上是一种可学习的信息流控制机制而不是单纯“算出一个新状态”而已。13. 输出层为什么和 RNN 一样注意这一句Y torch.mm(H, W_hq) b_q这和基础 RNN 完全一样。为什么因为 GRU 改进的是隐藏状态的内部更新机制而不是语言模型最终输出的形式。对于字符级语言模型来说最后仍然是当前隐藏状态映射到词表空间得到对每个字符的打分所以输出层部分不需要大改。14. 从零实现的 GRU 如何封装成模型类和上一节 RNN 一样通常也会封装一个“手写模型容器”。例如继续用之前那种思路net d2l.RNNModelScratch(vocab_size, num_hiddens, device, get_params, init_gru_state, gru)注意这非常漂亮。你会发现模型容器类几乎不用改只需要把参数初始化函数状态初始化函数前向传播函数换成 GRU 版本整个模型就从 RNN 变成了 GRU。这说明什么说明 GRU 和 RNN 在“接口层面”是一致的真正变化的是内部单元。15. 简洁实现PyTorch 里的nn.GRU从零实现看懂以后简洁实现就很好接受了。PyTorch 已经封装好了nn.GRU基本用法和nn.RNN非常相似。例如gru_layer nn.GRU(input_sizevocab_size, hidden_sizenum_hiddens)这里input_size仍然是每个时间步输入向量维度hidden_size仍然是隐藏状态维度然后前向传播接口也很像Y, state gru_layer(X, state)输出逻辑和nn.RNN一样Y所有时间步输出state最后隐藏状态16. 为什么nn.GRU用法和nn.RNN这么像因为从更高层视角看它们都属于循环序列建模模块它们对外接口基本一致输入序列初始状态输出序列最终状态只不过内部单元计算方式不同nn.RNN内部是基础循环单元。nn.GRU内部是门控循环单元。所以你可以把 GRU 看成是RNN API 体系下的一个更强单元版本这对工程使用非常友好。17. 简洁实现里的模型封装怎么写和简洁版 RNN 一样通常可以继续复用同样的模型外壳GRU 层负责序列递推线性层负责把隐藏状态映射到词表空间也就是说模型代码结构上几乎不用大动只要把rnn_layer nn.RNN(...)换成gru_layer nn.GRU(...)就能得到 GRU 版本模型。这说明GRU 的升级主要体现在循环单元内部不在外围框架。18. GRU 代码和 RNN 代码最本质的区别在哪里这是这一节最该点破的地方。RNN 从零实现核心只有一条隐藏状态更新公式。GRU 从零实现核心变成了先算更新门再算重置门再算候选状态再融合新旧状态所以如果你问GRU 相比 RNN代码上本质多了什么答案就是多了门控变量和基于门控的状态融合机制其他外部流程例如one-hot 输入初始状态输出层文本生成其实都还是同一套套路。19. GRU 训练时需要额外特殊处理吗整体训练流程和 RNN 基本一致。仍然是输入 token 序列输出对下一个 token 的预测用交叉熵损失反向传播梯度裁剪参数更新所以从训练框架角度GRU 并没有引入额外特别陌生的流程。变化的是内部状态更新更智能了。这也是为什么在工程里GRU 往往可以很平滑地替代基础 RNN。20. 为什么 GRU 常常比基础 RNN 更实用从代码层面你已经能看出来原因第一它可以显式保留旧状态更新门允许旧信息直接穿过去。第二它可以有选择地忽略部分历史重置门让模型不会被无关旧信息拖累。第三它改进了梯度传播路径虽然不能彻底消除所有问题但比基础 RNN 更容易训练稳定。所以很多时候GRU 是一种在复杂度和性能之间比较均衡的循环结构21. 这一节最该掌握什么如果从学习重点来看这一节最关键的是下面几件事。21.1 看懂参数初始化比 RNN 多在哪知道为什么 GRU 至少需要三组核心参数。21.2 看懂四条核心公式如何一一落到代码尤其是ZRH_tildeH之间的关系。21.3 理解最终状态更新是“新旧信息加权融合”这是 GRU 相比 RNN 的本质增强。21.4 知道nn.GRU和nn.RNN的接口非常接近方便后面工程应用。21.5 明白 GRU 的改进点主要发生在单元内部外围训练逻辑其实变化不大。22. 本节总结这一节我们学习了 GRU 的代码实现核心内容可以总结为以下几点。22.1 GRU 从零实现比 RNN 多了门控参数主要包括更新门参数重置门参数候选状态参数22.2 GRU 前向传播的关键是先算门再更新状态这让信息流更加可控。22.3 最终隐藏状态是旧状态和候选状态的门控融合而不是像基础 RNN 那样一次性混合更新。22.4 简洁实现里nn.GRU和nn.RNN用法高度相似只不过内部计算更强。22.5 GRU 是基础 RNN 向更强序列建模迈出的重要一步也是后面 LSTM 的前置基础。23. 学习感悟这一节很有意思因为你会第一次非常明显地感受到一个模型的提升不一定来自“推翻重来”也可能来自对信息流路径做精细控制。GRU 其实没有把循环神经网络完全重写它只是问了两个更聪明的问题以前的信息还值不值得保留新的信息该不该立刻写进去就是这两个问题让它比基础 RNN 更能“记事”。从这个角度看GRU 的优雅之处不在于公式复杂而在于它让记忆管理第一次变得真正有策略。

更多文章