实验探索:浅浅尝试多尺度Loss优化Qwen-Reranker效果

张开发
2026/4/12 4:45:10 15 分钟阅读

分享文章

实验探索:浅浅尝试多尺度Loss优化Qwen-Reranker效果
实验探索浅浅尝试多尺度Loss优化Qwen-Reranker效果一、背景与动机在 RAG检索增强生成系统中Reranker 扮演着精排的角色对检索结果进行二次排序。传统的 Reranker 训练通常采用单一的 Binary Cross Entropy Loss即PointWiseLoss或者ListWiseLoss但在实际场景中我们往往面临两个挑战正负样本不平衡正样本相关文档通常远少于负样本排序一致性我们不仅希望模型正确分类更希望同一 query 下的相关文档得分高于不相关文档局部与全局问题很多时候每个Batch的局部最优拉到全局并不一定最优尤其在数据质量不够高且数据量很大时候本文尝试通过多尺度 Loss 设计来同时解决这两个问题。二、核心实现2.1 整体架构基于 HuggingFaceTrainer实现自定义训练器支持四种 Loss 模式Loss 类型用途特点pointwise基础分类标准 Cross Entropyfocal样本不平衡自动降权简单样本listwise排序一致性Batch 内组间排序global_consistent综合方案Focal Listwise 组合2.2 训练器实现classRerankerTrainer(Trainer):def__init__(self,yes_token_id,no_token_id,loss_typepointwise,temperature0.05,focal_alpha0.25,focal_gamma2.0,listwise_weight0.1,*args,**kwargs):super().__init__(*args,**kwargs)self.yes_token_idyes_token_id self.no_token_idno_token_id self.loss_typeloss_type self.temperaturetemperature self.focal_alphafocal_alpha self.focal_gammafocal_gamma self.listwise_weightlistwise_weight self._step_accuracy_sum0.0self._step_accuracy_count0defcompute_loss(self,model,inputs,return_outputsFalse,num_items_in_batchNone):labelsinputs.pop(labels)# [batch]# 1. Forward passoutputsmodel(**inputs)logitsoutputs.logits# [batch, seq_len, vocab_size]# 2. 提取分类用的 Logits (基于 yes/no token)last_logitslogits[:,-1,:]# 取最后一个 tokenyes_logitslast_logits[:,self.yes_token_id]no_logitslast_logits[:,self.no_token_id]# 构造二分类 logits: [batch, 2]# index 0 为 no(非同款), index 1 为 yes(同款)binary_logitstorch.stack([no_logits,yes_logits],dim1)# 3. 根据 loss_type 计算 lossifself.loss_typepointwise:lossself._pointwise_loss(binary_logits,labels)elifself.loss_typefocal:lossself._focal_loss(binary_logits,labels)elifself.loss_typelistwise:lossself._listwise_loss(binary_logits,labels)elifself.loss_typeglobal_consistent:lossself._global_consistent_loss(binary_logits,labels)else:lossself._pointwise_loss(binary_logits,labels)# 4. 计算准确率并记录withtorch.no_grad():predstorch.argmax(binary_logits,dim1)correct(predslabels).sum().item()accuracycorrect/labels.size(0)self._step_accuracy_sumaccuracy self._step_accuracy_count1ifself.state.global_step%self.args.logging_steps0andself._step_accuracy_count0:avg_accuracyself._step_accuracy_sum/self._step_accuracy_count self.log({train_accuracy:avg_accuracy})self._step_accuracy_sum0.0self._step_accuracy_count0return(loss,outputs)ifreturn_outputselselossdef_pointwise_loss(self,binary_logits,labels):标准 Cross Entropy LossreturnF.cross_entropy(binary_logits,labels)def_focal_loss(self,binary_logits,labels):Focal Loss处理样本不平衡ce_lossF.cross_entropy(binary_logits,labels,reductionnone)pttorch.exp(-ce_loss)# 预测正确的概率focal_lossself.focal_alpha*(1-pt)**self.focal_gamma*ce_lossreturnfocal_loss.mean()def_listwise_loss(self,binary_logits,labels):Listwise Loss处理 Batch 内排序一致性loss_listwisetorch.tensor(0.0,devicebinary_logits.device)pos_indicestorch.nonzero(labels1).squeeze(-1)group_count0iflen(pos_indices)0:foridxinpos_indices:# 找到当前组的范围假设数据按组排列正样本在前后面是负样本startidx.item()# 寻找下一组的开始next_postorch.nonzero(labels[start1:]1)endstart1next_pos[0].item()iflen(next_pos)0elselabels.size(0)group_logitsbinary_logits[start:end,1]# 只取 yes 的得分进行组内对比ifgroup_logits.size(0)1:# Listwise 目标组内第一个正样本得分最高targettorch.tensor([0],devicebinary_logits.device)loss_listwiseloss_listwiseF.cross_entropy((group_logits/self.temperature).unsqueeze(0),target)group_count1ifgroup_count0:returnloss_listwise/group_countelse:# Fallback to pointwisereturnF.cross_entropy(binary_logits,labels)def_global_consistent_loss(self,binary_logits,labels):全局一致性 LossFocal Listwise# Loss A: Focal Loss (处理正负样本失衡)loss_pointwiseself._focal_loss(binary_logits,labels)# Loss B: Listwise Loss (处理 Batch 内排序一致性)loss_listwiseself._listwise_loss(binary_logits,labels)returnloss_pointwiseself.listwise_weight*loss_listwise三、Loss 设计思路3.1 Focal Loss应对样本不平衡Focal Loss 通过动态调整样本权重让模型更关注难分类的样本FL(pt)−αt(1−pt)γlog⁡(pt)\text{FL}(p_t) -\alpha_t (1 - p_t)^\gamma \log(p_t)FL(pt​)−αt​(1−pt​)γlog(pt​)α\alphaα正样本权重平衡正负样本比例γ\gammaγ聚焦参数控制难易样本的权重差异3.2 Listwise Loss优化排序一致性传统 pointwise loss 只关心单个样本的分类正确性但 Reranker 的核心目标是排序。Listwise Loss 确保同一 query 下正样本得分 负样本得分通过 temperature 参数控制排序 margin 的锐度3.3 Global Consistent Loss鱼和熊掌兼得组合 Focal Listwise同时优化分类准确率和排序一致性LglobalLfocalλ⋅LlistwiseL_{global} L_{focal} \lambda \cdot L_{listwise}Lglobal​Lfocal​λ⋅Llistwise​其中λ\lambdaλ控制 listwise 项的权重。四、使用建议数据组织确保 batch 内数据按(query, positive_doc, negative_docs...)排列超参调优focal_alpha0.25,focal_gamma2.0是经典配置listwise_weight0.1起根据排序指标调整temperature越小排序越激进但可能不稳定五、总结这次尝试的核心收获多尺度 Loss 设计能有效兼顾不同训练目标Focal Loss 对样本不平衡场景有显著帮助Listwise Loss 提升了排序一致性但依赖数据组织方式部分基于AI生成细节大家可以讨论

更多文章