别再被mask搞晕了!用Pytorch的nn.MultiheadAttention手把手带你过一遍Self-Attention(附代码)

张开发
2026/4/19 1:25:29 15 分钟阅读

分享文章

别再被mask搞晕了!用Pytorch的nn.MultiheadAttention手把手带你过一遍Self-Attention(附代码)
解密PyTorch中的注意力掩码从原理到实战的完整指南在自然语言处理领域Transformer架构已经成为处理序列数据的标配工具。而作为其核心组件自注意力机制(Self-Attention)的理解和正确使用尤其是其中的掩码(mask)机制往往是初学者进阶路上的第一道门槛。本文将带你深入剖析PyTorch中nn.MultiheadAttention模块的掩码工作原理通过代码实例演示如何避免常见陷阱。1. 为什么我们需要掩码想象你正在处理一批英文句子每个句子的长度各不相同。为了批量处理我们通常会用特殊标记(如pad)将短句子补齐到相同长度。但问题来了——在计算注意力权重时这些填充标记不应该参与计算否则会干扰模型对实际语义的理解。这就是key_padding_mask的用武之地。另一个场景是序列生成任务。当模型预测当前位置的单词时它理论上只能看到已经生成的左侧内容而不能偷看未来的单词。这种防止信息泄露的机制就是attn_mask(常称为因果掩码)的核心作用。提示两种掩码虽然都叫mask但解决的问题完全不同。前者处理外部强制的padding后者控制模型内部的注意力范围。让我们看一个直观的例子。假设我们有以下两个句子(已转换为token ID)# 原始句子(实际长度分别为3和4) sent1 [10, 20, 30] sent2 [40, 50, 60, 70] # 经过padding后的batch(统一长度为4) batch torch.tensor([ [10, 20, 30, 0], # 0是padding标记 [40, 50, 60, 70] ])2. 两种掩码的对比与实现2.1 key_padding_mask处理变长序列key_padding_mask是一个布尔型张量形状为(N, L)其中N是batch大小L是序列长度。它的设计哲学很简单True表示需要被掩蔽的位置。# 创建key_padding_mask mask torch.tensor([ [False, False, False, True], # 第一个句子的第4位置是padding [False, False, False, False] # 第二个句子无padding ])在实际应用中我们通常通过以下方式生成def create_padding_mask(batch, pad_token0): return (batch pad_token) # 找出所有padding位置2.2 attn_mask构建因果注意力attn_mask则更为复杂它是一个(L, L)的矩阵用于控制每个查询位置可以访问哪些键位置。在自回归任务(如文本生成)中我们需要一个下三角矩阵[[0, -inf, -inf, -inf], [0, 0, -inf, -inf], [0, 0, 0, -inf], [0, 0, 0, 0]]PyTorch中生成方法def create_causal_mask(size): return torch.triu(torch.full((size, size), float(-inf)), diagonal1)2.3 关键区别总结特性key_padding_maskattn_mask形状(N, L)(L, L)值类型布尔值浮点数(通常为0或-inf)主要用途屏蔽padding位置控制注意力范围应用阶段影响注意力得分计算影响注意力得分计算是否依赖输入内容是(基于padding位置)否(通常固定模式)3. 完整代码实战文本分类中的掩码应用让我们通过一个完整的文本分类示例演示如何正确使用两种掩码。import torch import torch.nn as nn class TextClassifier(nn.Module): def __init__(self, vocab_size, embed_dim, num_heads, num_classes): super().__init__() self.embedding nn.Embedding(vocab_size, embed_dim) self.attention nn.MultiheadAttention(embed_dim, num_heads) self.fc nn.Linear(embed_dim, num_classes) def forward(self, x): # x形状: (N, L) padding_mask (x 0) # 假设0是padding标记 # 嵌入层 embeddings self.embedding(x) # (N, L, E) embeddings embeddings.transpose(0, 1) # (L, N, E) # 自注意力 attn_output, _ self.attention( queryembeddings, keyembeddings, valueembeddings, key_padding_maskpadding_mask ) # 取第一个token作为分类特征 cls_output attn_output[0] # (N, E) return self.fc(cls_output)4. 常见陷阱与调试技巧4.1 形状不匹配错误最常见的错误是掩码形状不正确。记住key_padding_mask必须与key的序列维度匹配attn_mask必须是方阵边长等于序列长度4.2 掩码值混淆初学者常犯的错误是混淆两种掩码的值类型key_padding_mask用True表示屏蔽attn_mask用-inf表示屏蔽4.3 训练-推理不一致在文本生成任务中训练时通常使用全序列的因果掩码而推理时需要逐步生成。确保你的掩码逻辑在两种模式下一致# 训练模式 def train_step(batch): causal_mask create_causal_mask(batch.size(1)) output model(batch, attn_maskcausal_mask) # ...计算损失等 # 推理模式 def generate(model, prompt, max_len): for _ in range(max_len): # 每次只对当前序列生成掩码 mask create_causal_mask(len(prompt)) next_token model(prompt, attn_maskmask) prompt torch.cat([prompt, next_token]) return prompt4.4 性能优化当处理长序列时掩码操作可能成为性能瓶颈。考虑以下优化预先计算对于固定模式的attn_mask(如因果掩码)可以预先计算并缓存稀疏矩阵对于非常大的掩码使用稀疏张量表示内置优化PyTorch 1.8的nn.MultiheadAttention已经针对常见掩码模式进行了优化5. 高级应用自定义掩码模式除了标准的padding和因果掩码我们还可以创造性地使用掩码实现特殊功能5.1 局部注意力窗口限制每个token只能关注其附近一定范围内的tokendef create_local_mask(size, window_size3): mask torch.full((size, size), float(-inf)) for i in range(size): start max(0, i - window_size) end min(size, i window_size 1) mask[i, start:end] 0 return mask5.2 分层注意力在长文档处理中可以先对段落进行粗粒度注意力再在段落内进行细粒度注意力def create_hierarchical_mask(doc_structure): doc_structure: 列表表示每个段落的长度 例如[3, 5, 2]表示3个段落长度分别为3,5,2个token total_len sum(doc_structure) mask torch.zeros(total_len, total_len) offset 0 for seg_len in doc_structure: # 段内全连接 mask[offset:offsetseg_len, offset:offsetseg_len] 1 offset seg_len # 加上全局的段落级连接 segment_starts [sum(doc_structure[:i]) for i in range(len(doc_structure))] for i in segment_starts: for j in segment_starts: mask[i:, j:] 1 # 每个段落的第一个token可以看到其他段落的第一个token return mask5.3 多任务掩码在同时需要padding和因果掩码的场景如何组合两者PyTorch会自动处理这种组合# 同时使用两种掩码 output model( input, attn_maskcausal_mask, # 控制注意力方向 key_padding_maskpadding_mask # 过滤padding )6. 可视化调试技巧当掩码行为不符合预期时可视化是强大的调试工具def plot_attention(attention_weights, tokens, maskNone): import matplotlib.pyplot as plt fig, ax plt.subplots(figsize(10, 8)) cax ax.matshow(attention_weights, cmapviridis) if mask is not None: # 用红色标注被mask的位置 masked_positions mask float(-inf) for i in range(masked_positions.shape[0]): for j in range(masked_positions.shape[1]): if masked_positions[i, j]: ax.text(j, i, X, hacenter, vacenter, colorred) ax.set_xticks(range(len(tokens))) ax.set_yticks(range(len(tokens))) ax.set_xticklabels(tokens, rotation90) ax.set_yticklabels(tokens) fig.colorbar(cax) plt.show() # 示例使用 tokens [[CLS], Hello, world, [PAD]] attention torch.rand(4, 4) mask torch.tensor([ [0, 0, 0, 0], [0, 0, 0, -float(inf)], [0, 0, 0, -float(inf)], [0, -float(inf), -float(inf), -float(inf)] ]) plot_attention(attention, tokens, mask)7. 性能考量与最佳实践在实际项目中应用注意力掩码时还需要考虑以下工程因素设备转移确保掩码与模型在同一设备上mask mask.to(input.device) # 不要忘记!自动微分某些掩码操作可能破坏计算图避免在掩码创建过程中使用原生Python操作优先使用PyTorch内置函数如torch.triu内存占用对于极长序列全尺寸的attn_mask可能消耗大量内存考虑使用稀疏掩码或分块计算混合精度训练确保掩码张量的dtype与模型匹配-inf值在不同精度下可能有不同表示# 混合精度下的安全掩码创建 mask torch.triu(torch.ones(seq_len, seq_len), diagonal1) mask mask.masked_fill(mask 1, float(-inf)) mask mask.to(dtypetorch.float16) # 与模型精度一致

更多文章