PyTorch 学习笔记(10) : PyTorch torch.library

张开发
2026/4/18 19:45:28 15 分钟阅读

分享文章

PyTorch 学习笔记(10) : PyTorch torch.library
一、模块概述torch.library是 PyTorch 提供的用于扩展核心算子库的 API 集合主要功能包括功能说明测试自定义算子验证算子注册是否正确创建自定义算子在 Python 中定义新算子扩展现有算子为 C 注册的算子添加 Python 实现二、测试自定义算子2.1torch.library.opcheck()- 算子正确性检测核心作用验证自定义算子的元数据和属性是否符合 PyTorch 规范。检测内容test_schema# 检查 schema 与实现是否匹配test_autograd_registration# 检查自动微分是否正确注册test_faketensor# 检查 FakeTensor 内核是否正确test_aot_dispatch_dynamic# 检查与 torch.compile 的兼容性使用示例importtorchfromtorch.libraryimportcustom_opcustom_op(mylib::numpy_mul,mutates_args())defnumpy_mul(x:torch.Tensor,y:float)-torch.Tensor:x_npx.numpy(forceTrue)z_npx_np*yreturntorch.from_numpy(z_np).to(x.device)numpy_mul.register_fakedef_(x,y):returntorch.empty_like(x)# 测试多种输入场景sample_inputs[(torch.randn(3),3.14),(torch.randn(2,3,devicecuda),2.718),(torch.randn(1,10,requires_gradTrue),1.234),]forargsinsample_inputs:torch.library.opcheck(numpy_mul,args)# ✅ 通过则无输出⚠️注意opcheck与torch.autograd.gradcheck()互补前者测试 API 使用正确性后者测试梯度数学正确性。三、创建自定义算子3.1torch.library.custom_op()- 通用自定义算子核心参数参数说明name命名空间格式namespace::name如mylib::my_linearmutates_args被修改的参数名列表必须准确指定device_types支持的设备类型cpu、cuda等schema算子 schema 字符串推荐自动推断基础示例fromtorch.libraryimportcustom_opimportnumpyasnp# 示例1简单算子custom_op(mylib::numpy_sin,mutates_args())defnumpy_sin(x:torch.Tensor)-torch.Tensor:x_npx.cpu().numpy()y_npnp.sin(x_np)returntorch.from_numpy(y_np).to(devicex.device)# 示例2原地修改算子custom_op(mylib::numpy_sin_inplace,mutates_args{x},device_typescpu)defnumpy_sin_inplace(x:torch.Tensor)-None:x_npx.numpy()np.sin(x_np,outx_np)# 原地修改# 示例3工厂函数无输入 Tensorcustom_op(mylib::bar,mutates_args{},device_typescpu)defbar(device:torch.device)-torch.Tensor:returntorch.ones(3)3.2torch.library.triton_op()- Triton 内核封装适用场景当实现包含Triton 内核时使用允许torch.compile优化 Triton 代码。fromtorch.libraryimporttriton_op,wrap_tritonimporttritonfromtritonimportlanguageastltriton.jitdefadd_kernel(in_ptr0,in_ptr1,out_ptr,n_elements,BLOCK_SIZE:tl.constexpr):pidtl.program_id(axis0)block_startpid*BLOCK_SIZE offsetsblock_starttl.arange(0,BLOCK_SIZE)maskoffsetsn_elements xtl.load(in_ptr0offsets,maskmask)ytl.load(in_ptr1offsets,maskmask)tl.store(out_ptroffsets,xy,maskmask)triton_op(mylib::add,mutates_args{})defadd(x:torch.Tensor,y:torch.Tensor)-torch.Tensor:outputtorch.empty_like(x)n_elementsoutput.numel()defgrid(meta):return(triton.cdiv(n_elements,meta[BLOCK_SIZE]),)# 必须用 wrap_triton 包装wrap_triton(add_kernel)[grid](x,y,output,n_elements,16)returnoutput# 可与 torch.compile 配合使用torch.compiledeff(x,y):returnadd(x,y)3.3torch.library.wrap_triton()- Triton 内核追踪包装作用使 Triton 内核可被make_fx或torch.export捕获到计算图中。四、扩展现有算子4.1register_kernel()- 注册设备特定实现custom_op(mylib::numpy_sin,mutates_args(),device_typescpu)defnumpy_sin(x:torch.Tensor)-torch.Tensor:x_npx.numpy()y_npnp.sin(x_np)returntorch.from_numpy(y_np)# 为 CUDA 添加实现torch.library.register_kernel(mylib::numpy_sin,cuda)def_(x):x_npx.cpu().numpy()y_npnp.sin(x_np)returntorch.from_numpy(y_np).to(devicex.device)4.2register_fake()/impl_abstract()- 注册 FakeTensor 实现别名impl_abstract在 PyTorch 2.4 后重命名为register_fake。作用定义算子在无数据 TensorFakeTensor/Meta Tensor上的行为支持编译和导出。# 示例1常规算子torch.library.register_fake(mylib::custom_linear)def_(x,weight,bias):assertx.dim()2andweight.dim()2assertx.shape[1]weight.shape[1]return(x weight.t())bias# 返回元数据正确的空 Tensor# 示例2数据依赖形状如 nonzerotorch.library.register_fake(mylib::custom_nonzero)def_(x):ctxtorch.library.get_ctx()# 获取上下文nnzctx.new_dynamic_size()# 创建动态符号整数returnx.new_empty([nnz,x.dim()],dtypetorch.int64)4.3register_autograd()- 注册反向传播custom_op(mylib::numpy_sin,mutates_args())defnumpy_sin(x:torch.Tensor)-torch.Tensor:x_npx.cpu().numpy()y_npnp.sin(x_np)returntorch.from_numpy(y_np).to(devicex.device)defsetup_context(ctx,inputs,output):x,inputs ctx.save_for_backward(x)# 保存前向所需数据defbackward(ctx,grad):x,ctx.saved_tensorsreturngrad*x.cos()# 返回梯度torch.library.register_autograd(mylib::numpy_sin,backward,setup_contextsetup_context)# 测试xtorch.randn(3,requires_gradTrue)ynumpy_sin(x)grad_xtorch.autograd.grad(y,x,torch.ones_like(y))[0]4.4register_autocast()- 自动类型转换custom_op(mylib::my_sin,mutates_args())defmy_sin(x:torch.Tensor)-torch.Tensor:returntorch.sin(x)# 注册 CUDA 下的 FP16 自动转换torch.library.register_autocast(mylib::my_sin,cuda,torch.float16)# 使用withtorch.autocast(cuda,dtypetorch.float16):ytorch.ops.mylib.my_sin(x)# x 自动转为 fp16输出也是 fp164.5register_vmap()- 批量映射支持torch.library.register_vmap(mylib::numpy_mul)defnumpy_mul_vmap(info,in_dims,x,y):x_bdim,y_bdimin_dims# 调整维度进行广播计算xx.movedim(x_bdim,-1)ifx_bdimisnotNoneelsex.unsqueeze(-1)yy.movedim(y_bdim,-1)ify_bdimisnotNoneelsey.unsqueeze(-1)resultx*yreturnresult.movedim(-1,0),0# 返回 (输出, 输出维度)# 使用xtorch.randn(3)ytorch.randn(3)torch.vmap(numpy_mul)(x,y)# 批量映射4.6register_torch_dispatch()- TorchDispatch 规则classMyMode(torch.utils._python_dispatch.TorchDispatchMode):def__torch_dispatch__(self,func,types,args(),kwargsNone):returnfunc(*args,**kwargs)torch.library.register_torch_dispatch(mylib::foo,MyMode)def_(mode,func,types,args,kwargs):x,argsreturnx1# 在 MyMode 下行为改变# 测试withMyMode():yfoo(x)# 输出 x 1五、辅助工具5.1infer_schema()- 从类型注解推断 Schemadeffoo_impl(x:torch.Tensor)-torch.Tensor:returnx.sin()schematorch.library.infer_schema(foo_impl,op_namefoo,mutates_args{})print(schema)# 输出: foo(Tensor x) - Tensor5.2get_ctx()- 获取 Fake 实现上下文仅在register_fake内部有效用于创建动态符号尺寸。5.3get_kernel()- 获取已注册的内核# 获取 aten::add 的 CPU 内核kerneltorch.library.get_kernel(aten::add.Tensor,CPU)# 用于实现条件分发original_sintorch.library.get_kernel(aten::sin,CPU)defconditional_sin(dispatch_keys,x):if(x0).any():returnoriginal_sin.call_boxed(dispatch_keys,x)returntorch.zeros_like(x)六、底层 APILibrary 类警告建议优先使用上述高级 API底层 API 需要理解 PyTorch Dispatcher 机制。# 创建库my_libtorch.library.Library(mylib,DEF)# 定义新算子my_libtorch.library.Library(aten,IMPL)# 扩展现有算子# 定义算子my_lib.define(sum(Tensor self) - Tensor)# 注册实现my_lib.impl(div.Tensor,div_cpu,CPU)# 注册 fallbackmy_lib.fallback(fallback_kernel,Autocast)七、最佳实践总结场景推荐 API快速创建自定义算子custom_op()包装 Triton 内核triton_op()wrap_triton()支持 torch.compile必须实现register_fake()支持自动微分register_autograd()支持多设备register_kernel()按设备注册支持 torch.vmapregister_vmap()支持自动混合精度register_autocast()验证算子正确性opcheck()gradcheck()八、参考链接PyTorch Custom Operators 完整指南PyTorch Dispatcher 博客提示本文基于 PyTorch 2.11 官方文档整理建议结合官方最新文档使用。

更多文章