用PyTorch复现SRCNN:三行代码搞定图像超分,重温2015年的经典

张开发
2026/4/16 2:46:11 15 分钟阅读

分享文章

用PyTorch复现SRCNN:三行代码搞定图像超分,重温2015年的经典
用PyTorch复现SRCNN三行代码搞定图像超分重温2015年的经典在深度学习模型日益复杂的今天动辄数百层的网络架构已成为常态。然而回望2015年一个仅由三层卷积构成的模型——SRCNN却开创了深度学习在图像超分辨率领域的先河。本文将带你用PyTorch亲手实现这一经典模型体验其简洁之美与高效性能。1. SRCNN模型解析与PyTorch实现SRCNNSuper-Resolution Convolutional Neural Network的核心思想是将传统超分辨率方法中的三个关键步骤——特征提取、非线性映射和重建——统一到一个端到端的卷积神经网络中。这种设计不仅简化了流程还通过数据驱动的方式自动学习最优映射。1.1 模型架构详解SRCNN的网络结构异常简洁仅包含三个卷积层import torch.nn as nn class SRCNN(nn.Module): def __init__(self, num_channels1): super(SRCNN, self).__init__() self.conv1 nn.Conv2d(num_channels, 64, kernel_size9, padding4) self.conv2 nn.Conv2d(64, 32, kernel_size5, padding2) self.conv3 nn.Conv2d(32, num_channels, kernel_size5, padding2) self.relu nn.ReLU(inplaceTrue) def forward(self, x): x self.relu(self.conv1(x)) x self.relu(self.conv2(x)) x self.conv3(x) return x各层功能解析层输入通道输出通道核大小功能描述Conv11649×9提取局部图像特征Conv264325×5非线性特征映射Conv33215×5高分辨率图像重建提示对于彩色图像处理只需将num_channels参数设为3即可模型会自动适应RGB三通道输入。1.2 模型初始化技巧虽然SRCNN结构简单但合理的初始化对训练效果至关重要def weights_init(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu) if m.bias is not None: nn.init.constant_(m.bias, 0) model SRCNN() model.apply(weights_init)2. 数据准备与预处理2.1 数据集选择与处理DIV2K是超分辨率任务中最常用的数据集之一包含800张训练图像和100张验证图像。我们可以使用TorchVision进行高效加载from torchvision import transforms transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean[0.5], std[0.5]) ]) class DIV2KDataset(Dataset): def __init__(self, hr_dir, lr_dir, scale2, transformNone): self.hr_images sorted(glob.glob(f{hr_dir}/*.png)) self.lr_images sorted(glob.glob(f{lr_dir}/*.png)) self.transform transform self.scale scale def __getitem__(self, idx): hr_img Image.open(self.hr_images[idx]) lr_img Image.open(self.lr_images[idx]) if self.transform: hr_img self.transform(hr_img) lr_img self.transform(lr_img) return lr_img, hr_img2.2 数据增强策略为提高模型泛化能力建议采用以下增强组合随机旋转90°, 180°, 270°水平/垂直翻转随机裁剪通常裁剪为48×48的小块色彩抖动针对彩色图像train_transform transforms.Compose([ transforms.RandomCrop(48), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transforms.Normalize(mean[0.5], std[0.5]) ])3. 模型训练与调优3.1 损失函数与优化器选择SRCNN通常使用L1或L2损失函数各有优劣L1 LossMAE对异常值更鲁棒收敛稳定L2 LossMSE强调大误差惩罚可能产生更锐利的结果criterion nn.L1Loss() # 或 nn.MSELoss() optimizer torch.optim.Adam(model.parameters(), lr1e-4) scheduler torch.optim.lr_scheduler.StepLR(optimizer, step_size50, gamma0.5)3.2 训练过程监控典型的训练循环实现def train(model, dataloader, criterion, optimizer, device): model.train() running_loss 0.0 for lr_imgs, hr_imgs in dataloader: lr_imgs lr_imgs.to(device) hr_imgs hr_imgs.to(device) optimizer.zero_grad() outputs model(lr_imgs) loss criterion(outputs, hr_imgs) loss.backward() optimizer.step() running_loss loss.item() return running_loss / len(dataloader)常见训练曲线分析理想情况训练和验证损失同步下降最终趋于平稳过拟合训练损失持续下降而验证损失开始上升欠拟合训练和验证损失都下降缓慢或停滞注意SRCNN训练通常需要100-300个epoch才能达到较好效果过早停止可能导致性能不佳。4. 模型应用与效果评估4.1 单图超分辨率实践训练完成后可以轻松将模型应用于自己的图像def enhance_image(model, image_path, device): img Image.open(image_path).convert(L) # 转为灰度 img_tensor transform(img).unsqueeze(0).to(device) with torch.no_grad(): output model(img_tensor) enhanced_img transforms.ToPILImage()(output.squeeze().cpu()) return enhanced_img4.2 性能评估指标常用超分辨率评估指标对比指标计算方式特点PSNR峰值信噪比计算简单与人类感知相关性一般SSIM结构相似性更符合人类视觉感知LPIPS学习感知相似性基于深度学习评估最准确from skimage.metrics import peak_signal_noise_ratio as psnr from skimage.metrics import structural_similarity as ssim def evaluate(hr_img, sr_img): psnr_value psnr(hr_img, sr_img, data_range1.0) ssim_value ssim(hr_img, sr_img, multichannelTrue, data_range1.0) return psnr_value, ssim_value4.3 实际应用技巧边缘处理对于边界区域可适当扩展padding大图处理对于大尺寸图像可分块处理再拼接多尺度增强可尝试不同放大倍数的级联处理def process_large_image(model, large_img, patch_size256, overlap32): patches split_into_patches(large_img, patch_size, overlap) enhanced_patches [] for patch in patches: enhanced model(patch) enhanced_patches.append(enhanced) return merge_patches(enhanced_patches, overlap)5. 进阶优化方向虽然SRCNN结构简单但仍有多种优化空间5.1 网络结构改进增加残差连接类似VDSR使用更高效的激活函数如PReLU引入注意力机制class EnhancedSRCNN(nn.Module): def __init__(self, num_channels1): super().__init__() self.conv1 nn.Conv2d(num_channels, 64, 9, padding4) self.prelu1 nn.PReLU() self.conv2 nn.Conv2d(64, 32, 5, padding2) self.prelu2 nn.PReLU() self.conv3 nn.Conv2d(32, num_channels, 5, padding2) self.attention nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(64, 64//8, 1), nn.ReLU(), nn.Conv2d(64//8, 64, 1), nn.Sigmoid() ) def forward(self, x): x self.prelu1(self.conv1(x)) attention self.attention(x) x x * attention x self.prelu2(self.conv2(x)) return self.conv3(x)5.2 训练策略优化渐进式学习率调整多阶段训练先低分辨率后高分辨率对抗训练引入GAN损失# 对抗训练示例 discriminator ... # 定义判别器 adv_criterion nn.BCEWithLogitsLoss() def adversarial_loss(real_pred, fake_pred): real_loss adv_criterion(real_pred, torch.ones_like(real_pred)) fake_loss adv_criterion(fake_pred, torch.zeros_like(fake_pred)) return (real_loss fake_loss) / 2在实际项目中我发现结合L1损失和感知损失使用VGG特征往往能取得更好的视觉效果。对于老照片修复可以先用SRCNN进行超分辨率处理再配合传统的去噪算法效果通常比单独使用任何一种方法都要好。

更多文章