测试几种不同位置编码中旋转的耗时分析

张开发
2026/4/11 22:41:18 15 分钟阅读

分享文章

测试几种不同位置编码中旋转的耗时分析
importtorchimporttime# 四个版本 defrotate_half_original(x):x1,x2x.chunk(2,dim-1)returntorch.cat((-x2,x1),dim-1)defrotate_half_v1(x:torch.Tensor)-torch.Tensor:dx.size(-1)//2x1x[...,:d]x2x[...,d:]returntorch.cat((-x2,x1),dim-1)defrotate_half_v2(x:torch.Tensor)-torch.Tensor:dx.size(-1)//2outx.clone()out[...,:d]-x[...,d:]out[...,d:]x[...,:d]returnoutdefrotate_half_v3(x:torch.Tensor)-torch.Tensor:dx.shape[-1]//2a,bx.split(d,dim-1)returntorch.concat((-b,a),dim-1)# 速度测试 defbenchmark(func,x,name,steps100000):torch.cuda.synchronize()iftorch.cuda.is_available()elseNonestarttime.time()for_inrange(steps):func(x)torch.cuda.synchronize()iftorch.cuda.is_available()elseNoneprint(f{name:18s}|{time.time()-start:.4f}s)# 重点用大张量测试 if__name____main__:print(*60)print( 真实大张量测试Transformer 真实场景)print(*60)# 真实场景batch, seq_len, hidden_dimxtorch.randn(32,128,512)# 大张量关键在这里benchmark(rotate_half_original,x,原始版本)benchmark(rotate_half_v1,x,改进版 v1)benchmark(rotate_half_v2,x, 改进版 v2最快)benchmark(rotate_half_v3,x,改进版 v3)print(*60)

更多文章