From PointNet to PointNet++: A Hands-On PyTorch Walkthrough of the 'Receptive Field' Evolution in Point Cloud Processing

张开发
2026/4/18 11:03:43 · 15 min read


From PointNet to PointNet++: Building a Hierarchical Feature-Extraction System for Point Clouds with PyTorch

Point clouds are rapidly becoming a key information carrier in computer vision and robotics. Unlike regular 2D images, point clouds preserve geometric information directly in 3D space, but that comes with a distinctive processing challenge: how do you extract features effectively from an unordered, non-uniformly distributed set of points? The PointNet family of algorithms offered a pioneering answer. This article walks through how PointNet++ overcomes the limitations of PointNet with a hierarchical architecture, and implements the key modules step by step.

1. Fundamental Challenges of Point Clouds and the PointNet Breakthrough

3D point cloud data has several inherent properties: it is unordered (points have no fixed order), unstructured (there is no grid connectivity between points), and non-uniform in density (sampling density differs markedly across regions). These properties prevent conventional convolutional neural networks from being applied directly. PointNet, proposed in 2017, handles the ordering problem with a clever symmetric function: max pooling.

```python
import torch
import torch.nn as nn

class PointNetBasic(nn.Module):
    def __init__(self, in_dim=3, out_dim=1024):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim)
        )

    def forward(self, x):
        # x: (B, N, 3)
        features = self.mlp(x)                          # (B, N, out_dim)
        global_feature = torch.max(features, dim=1)[0]  # (B, out_dim)
        return global_feature
```

This basic version can handle unordered point sets, but it has clear limitations:

- Missing local features: a single global max pool discards local geometric detail.
- A single receptive field: it cannot build the multi-level feature representations a CNN does.
- Density sensitivity: it is not robust to changes in sampling density.

Tip: PointNet's breakthrough was proving that deep networks can consume point clouds directly. But, much like a fully connected network applied to images, it lacks local feature extraction.

2. The Core Architecture of PointNet++

PointNet++ solves these problems with hierarchical feature learning. Its core idea is to apply PointNet recursively at different scales. The network is composed of Set Abstraction (SA) layers, each containing three key components.

2.1 Farthest Point Sampling (FPS)

The sampling layer decides where feature extraction is centered. Compared with random sampling, FPS covers the point cloud far more evenly:

```python
def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: point cloud, [B, N, 3]
        npoint: number of points to sample
    Return:
        centroids: indices of sampled points, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[torch.arange(B), farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
```

This sampling scheme ensures that the selected centroids represent the overall shape well even when the point cloud density is uneven.

2.2 Ball Query Grouping

The grouping layer defines each centroid's local neighborhood. Compared with KNN, ball query copes better with density variation:

```python
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: query radius
        nsample: maximum number of points per region
        xyz: all point coordinates, [B, N, 3]
        new_xyz: query centroid coordinates, [B, S, 3]
    Return:
        group_idx: grouping indices, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).view(1, 1, N).repeat([B, S, 1]).to(device)
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N                  # mark points outside the ball
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N                                  # pad empty slots with the first neighbor
    group_idx[mask] = group_first[mask]
    return group_idx
```

Suggested parameter choices:

| Parameter | Typical value | Effect |
| --- | --- | --- |
| radius r | 0.1–0.5 | Size of the local region; a larger value means a larger receptive field |
| nsample | 16–64 | Number of points per local region; controls compute cost |

2.3 The Hierarchical PointNet Module

Each local region is passed through a small PointNet to extract features:

```python
import torch.nn.functional as F

class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel + 3  # +3 for the concatenated relative xyz
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        # xyz: [B, 3, N] coordinates, points: [B, D, N] features or None
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        B, N, C = xyz.shape
        if self.npoint is None:
            # group all remaining points into a single region (final SA layer)
            new_xyz = torch.zeros(B, 1, C, device=xyz.device)
            grouped_xyz = xyz.view(B, 1, N, C)
        else:
            S = self.npoint
            new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
            idx = query_ball_point(self.radius, self.nsample, xyz, new_xyz)
            grouped_xyz = index_points(xyz, idx)
            grouped_xyz = grouped_xyz - new_xyz.view(B, S, 1, C)  # relative coordinates
        if points is not None:
            grouped_points = points.view(B, 1, N, -1) if self.npoint is None else index_points(points, idx)
            new_points = torch.cat([grouped_xyz, grouped_points], dim=-1)
        else:
            new_points = grouped_xyz
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, S]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        new_points = torch.max(new_points, 2)[0]     # pool over each local region
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points
```

When `npoint` is `None`, the layer skips sampling and grouping and treats the entire point cloud as one region; the final SA layer of the full network in Section 4 relies on this.
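The code above calls two helper functions, `square_distance` and `index_points`, that never appear in the article itself. A minimal sketch of both, assuming the batched `[B, N, C]` tensor layout used throughout this section:

```python
import torch

def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two point sets.
    src: [B, N, C], dst: [B, M, C] -> [B, N, M].
    Expands ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 to avoid explicit loops."""
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def index_points(points, idx):
    """Gather points by per-batch indices.
    points: [B, N, C]; idx: [B, S] or [B, S, K] -> [B, S, C] or [B, S, K, C]."""
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
    return points[batch_indices, idx, :]
```

`square_distance` is what `query_ball_point` uses to test membership in each ball, and `index_points` performs the gather in both the sampling and grouping steps.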
3. Advanced Techniques for Non-uniform Density

Density variation is a common challenge in real point clouds. PointNet++ proposes two remedies:

3.1 Multi-Scale Grouping (MSG)

MSG extracts features at several radii around each centroid simultaneously and fuses them by concatenation:

```python
class PointNetMSG(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super().__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3  # +3 for the grouped relative xyz
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        # xyz: [B, 3, N], points: [B, D, N] or None
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            nsample = self.nsample_list[i]
            idx = query_ball_point(radius, nsample, xyz, new_xyz)
            grouped_xyz = index_points(xyz, idx)
            grouped_xyz = grouped_xyz - new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz
            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D+3, nsample, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points = F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)
        new_xyz = new_xyz.permute(0, 2, 1)
        new_points = torch.cat(new_points_list, dim=1)    # concatenate across scales
        return new_xyz, new_points
```

3.2 Multi-Resolution Grouping (MRG)

A more efficient way to fuse information is to combine features from different levels of the hierarchy. The module below lifts coarse-level features (`xyz2`, `points2`) up to the fine level (`xyz1`, `points1`) by inverse-distance interpolation over the three nearest neighbors:

```python
class PointNetMRG(nn.Module):
    def __init__(self, in_channel1, in_channel2, mlp):
        super().__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel1 + in_channel2
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, points1, xyz2, points2):
        # xyz1: [B, N, 3], points1: [B, N, C1] or None
        # xyz2: [B, M, 3], points2: [B, M, C2], with M < N
        B, N, _ = xyz1.shape
        dists = square_distance(xyz1, xyz2)
        dists, idx = torch.topk(dists, 3, dim=-1, largest=False)  # 3 nearest coarse points
        dist_recip = 1.0 / (dists + 1e-8)
        norm = torch.sum(dist_recip, dim=2, keepdim=True)
        weight = dist_recip / norm                                 # inverse-distance weights
        interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
        if points1 is not None:
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points
        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points.permute(0, 2, 1)
```

4. The Full Network and Training Tips

Putting the modules above together yields the complete PointNet++ classification network:

```python
class PointNet2ClsMSG(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.sa1 = PointNetMSG(
            npoint=512, radius_list=[0.1, 0.2, 0.4], nsample_list=[16, 32, 128],
            in_channel=0, mlp_list=[[32, 32, 64], [64, 64, 128], [64, 96, 128]]
        )
        self.sa2 = PointNetMSG(
            npoint=128, radius_list=[0.2, 0.4, 0.8], nsample_list=[32, 64, 128],
            in_channel=64 + 128 + 128, mlp_list=[[64, 64, 128], [128, 128, 256], [128, 128, 256]]
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,  # npoint=None: group all points
            in_channel=128 + 256 + 256, mlp=[256, 512, 1024]
        )
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz):
        # xyz: [B, 3, N]
        B = xyz.shape[0]
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, -1)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        return x
```

Key techniques during training:

Data augmentation:

- random rotation of the point cloud for viewpoint diversity;
- random scaling to simulate objects at different distances;
- Gaussian jitter for robustness.

Learning-rate schedule:

```python
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=200, eta_min=1e-5
)
```

Loss function:

```python
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```

A typical training curve on ModelNet40:

| Epochs | Train accuracy | Validation accuracy |
| --- | --- | --- |
| 50 | 85.2% | 83.7% |
| 100 | 89.6% | 87.2% |
| 200 | 91.3% | 89.5% |

Note: in real training runs, apply early stopping on a validation set to prevent overfitting.
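The three augmentation bullets above can be collected into a single function. The rotation is about the vertical axis; the scale range and noise level below are illustrative defaults, not values from the article:

```python
import torch

def augment(points):
    """Augment a batch of point clouds [B, N, 3] with a random rotation
    about the up (y) axis, a random uniform scale, and Gaussian jitter."""
    B = points.shape[0]
    theta = torch.rand(B) * 2 * torch.pi
    c, s = torch.cos(theta), torch.sin(theta)
    rot = torch.zeros(B, 3, 3)
    rot[:, 0, 0], rot[:, 0, 2] = c, s
    rot[:, 1, 1] = 1.0
    rot[:, 2, 0], rot[:, 2, 2] = -s, c
    points = torch.bmm(points, rot)                   # rotate each cloud
    scale = torch.empty(B, 1, 1).uniform_(0.8, 1.25)  # per-cloud scaling
    noise = torch.randn_like(points) * 0.01           # Gaussian jitter
    return points * scale + noise
```

Apply it to the raw `[B, N, 3]` batch, then permute to the `[B, 3, N]` layout the network expects.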
