Attention Mask在Seq-to-Seq生成模型中的核心作用与实现解析

张开发
2026/4/19 18:57:28 15 分钟阅读

分享文章

Attention Mask在Seq-to-Seq生成模型中的核心作用与实现解析
1. Attention Mask在Seq-to-Seq模型中的核心作用第一次用BART做文本生成时我盯着输出结果百思不得其解——为什么模型生成的句子前半段很通顺后半段却开始胡言乱语直到我注意到attention mask的设置问题才恍然大悟。这就像教小孩写作文时如果让他看到后面的参考答案他就永远学不会自主创作了。在Seq-to-Seq架构中attention mask本质上是个信息过滤器。想象你在嘈杂的咖啡厅里专注读书大脑会自动屏蔽周围噪音——这就是人脑的attention mask。Transformer模型通过三种典型mask实现类似功能编码器全可见mask像全景相机每个词能看到整个输入序列。处理[CLS] 今天 天气 真好 [SEP]时天气可以同时关注前后词解码器因果mask像遮住试卷答案的纸板确保生成第t个词时只能看前t-1个词。生成我爱时不允许提前看到中国前缀mask混合前两种模式常见于UniLM。比如处理输入北京 [SEP] 输出是首都时允许首都看到整个输入但看不到后面的[EOS]实际项目中我曾用T5生成商品标题。没加解码器mask时模型会把商品参数直接抄到标题里如生成手机 6.7寸 骁龙8Gen2 3999。加上因果mask后才学会组织自然语言高性能骁龙8Gen2大屏手机仅售3999元。2. 从UniLM看三种语言模型范式UniLM论文就像给attention mask玩法写了本百科全书。去年优化客服问答系统时我对比过这三种模式2.1 单向语言模型Unidirectional LM典型代表GPT系列mask矩阵示例[0, -∞, -∞] [0, 0, -∞] [0, 0, 0]实战坑点做文本续写时右到左(left-to-right)模型会生成好 天气 今天这样的倒装句。解决方案是统一训练和推理的方向2.2 双向语言模型Bidirectional LM典型代表BERTmask矩阵示例[0, 0, 0] [0, 0, 0] [0, 0, 0]特殊技巧在分类任务中我习惯把[CLS]位置的mask设为全0强迫模型通过该token聚合全局信息2.3 序列到序列语言模型Seq-to-Seq LM典型代表BART、T5mask矩阵示例输入3词输出2词[0, 0, 0, -∞, -∞] [0, 0, 0, -∞, -∞] [0, 0, 0, -∞, -∞] [0, 0, 0, 0, -∞] [0, 0, 0, 0, 0]业务场景在电商摘要生成中这种mask让模型在编码阶段看到全部商品描述解码阶段只能看到已生成的部分摘要3. HuggingFace中的mask实现细节打开transformers库的modeling_utils.py你会找到这两个关键函数3.1 _expand_mask函数解析这个函数处理的是编码器mask主要应对变长输入。比如批量处理两个句子你好 [PAD] [PAD]今天 天气 真好对应的原始mask应该是[[1, 0, 0], [1, 1, 1]]经过_expand_mask变换后# 形状变为 [2, 1, 3, 3] [ [[[0, -inf, -inf], # 你只能看到你 [0, 0, -inf], # 好能看到你好 [0, 0, 0]]], # [PAD]能看到全部但后续会被过滤 [[[0, 0, 0], [0, 0, 0], [0, 0, 0]]] ]实际调试时我发现如果忘记把mask转为bool类型会导致某些GPU上出现精度错误。这是个容易踩的坑。3.2 _make_causal_mask函数精要这是解码器的核心保护机制以下面代码为例def _make_causal_mask(input_ids_shape, dtype, past_key_values_length0): bsz, tgt_len input_ids_shape mask torch.full((tgt_len, tgt_len), float(-inf)) mask_cond torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond (mask_cond 1).view(mask.size(-1), 1), 0) return mask生成我爱中国时的mask矩阵[0, -∞, -∞, -∞] [0, 0, -∞, -∞] [0, 0, 0, -∞] [0, 0, 0, 0]在实现对话系统时past_key_values_length参数特别有用。它允许模型在生成第N轮回复时能看到前N-1轮的对话历史。4. 工业级应用中的进阶技巧在日均千万级请求的新闻摘要系统中我们总结了这些实战经验4.1 动态mask优化内存优化对于固定长度输入预计算mask矩阵并缓存。我们的实验显示这能使推理速度提升15%混合精度训练mask矩阵需要与logits保持相同数据类型。使用fp16时要注意-inf会被截断我们改用-1e4替代4.2 特殊场景适配长文本生成当序列超过1024时传统mask会耗尽显存。我们采用块稀疏mask就像这样[[0, -∞, -∞, ..., -∞], [0, 0, -∞, ..., -∞], ..., [0, 0, 0, ..., 0]] # 只保留对角线附近的注意力多模态输入处理图文混合输入时我们设计跨模态mask允许文本关注图像区域但禁止反向关注4.3 调试技巧可视化工具用seaborn绘制mask矩阵一眼就能发现形状错误import seaborn as sns sns.heatmap(mask[0,0].cpu().numpy())梯度检查如果模型不收敛检查mask是否意外阻断了有效梯度传播。我们曾遇到因mask误置导致encoder梯度为零的案例在最近的项目中我们还尝试了可学习maskLearned Attention Mask。让模型自行决定哪些位置应该被屏蔽这在抽象摘要任务中获得了3.2%的ROUGE提升。

更多文章