整体框架

SR,即super resolution,即超分辨率。CNN相对来说比较著名,就是卷积神经网络了。从名字可以看出,SRCNN是首个应用于超分辨领域的卷积神经网络,事实上也的确如此。

所谓超分辨率,就是把低分辨率(LR, Low Resolution)图片放大为高分辨率(HR, High Resolution)的过程。由于是开山之作,SRCNN相对比较简单,总共分三步

  1. 输入LR图像 超分辨网络SRCNN的Pytorch实现 ,经双三次(bicubic)插值,被放大成目标尺寸,得到 超分辨网络SRCNN的Pytorch实现
  2. 通过三层卷积网络拟合非线性映射
  3. 输出HR图像结果 超分辨网络SRCNN的Pytorch实现

训练的目标损失是最小化SR图像 超分辨网络SRCNN的Pytorch实现 和原高分辨率图像 超分辨网络SRCNN的Pytorch实现 像素差的均方误差

超分辨网络SRCNN的Pytorch实现

其中, 超分辨网络SRCNN的Pytorch实现 为训练样本数,参数更新公式为

超分辨网络SRCNN的Pytorch实现

网络模型

其网络结构如下

如前所述,网络分为三个卷积层

  1. 维度是 超分辨网络SRCNN的Pytorch实现 ,表示输入图像通道数为1,进行卷积运算的核尺寸为 超分辨网络SRCNN的Pytorch实现 ,输出深度为64。
  2. 维度是 超分辨网络SRCNN的Pytorch实现 ,64即上一层输出,32为下一层输出。
  3. 尺寸为 超分辨网络SRCNN的Pytorch实现 。它的输出是单通道图像,与输入相同。

所以这个模型实现起来并不难

# models.py
class SRCNN(nn.Module):
    def __init__(self, nChannel=1):
        super(SRCNN,self).__init__()
        self.conv1 = nn.Conv2d(nChannel, 64,
            kernel_size=9, padding=9//2)
        self.conv2 = nn.Conv2d(64, 32,
            kernel_size=5, padding=5//2)
        self.conv3 = nn.Conv2d(32, nChannel, 
            kernel_size=5, padding=5//2)
        self.relu = nn.ReLU(inplace=True)
    def forward(self,x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

数据集

训练数据集可手动生成,设放大倍数为scale,考虑到原始数据未必会被scale整除,所以要重新规划一下图像尺寸,所以训练数据集的生成分为三步:

  1. 将原始图像通过双三次插值重设尺寸,使之可被scale整除,作为高分辨图像数据HR
  2. 将HR通过双三次插值压缩scale倍,为低分辨图像的原始数据
  3. 将低分辨图像通过双三次插值放大scale倍,与HR图像维度相等,作为低分辨图像数据LR

最后,可通过h5py将训练数据分块并打包,其生成代码为

import h5py
import PIL.Image as pImg
def rgb2gray(img):
    return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
# imgPath为图像路径;h5Path为存储路径;scale为放大倍数
# pSize为patch尺寸; pStride为步长
def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):
    h5_file = h5py.File(h5Path, 'w')
    lrPatches, hrPatches = [], []       #用于存储低分辨率和高分辨率的patch
    for p in sorted(glob.glob(f'{imgPath}/*')):
        hr = pImg.open(p).convert('RGB')
        lrWidth, lrHeight = hr.width // scale, hr.height // scale
        # width, height为可被scale整除的训练数据尺寸
        width, height = lrWidth*scale, lrHeight*scale
        hr = hr.resize((width, height), resample=pImg.BICUBIC)
        lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)
        lr = lr.resize((width, height), resample=pImg.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = rgb2gray(hr)
        lr = rgb2gray(lr)
        # 将数据分割
        for i in range(0, height - pSize + 1, pStride):
            for j in range(0, width - pSize + 1, pStride):
                lrPatches.append(lr[i:i + pSize, j:j + pSize])
                hrPatches.append(hr[i:i + pSize, j:j + pSize])
    h5_file.create_dataset('lr', data=np.array(lrPatches))
    h5_file.create_dataset('hr', data=np.array(hrPatches))
    h5_file.close()

以比较常见的T91数据集为例,通过上面的方法,可以得到一个181M的h5文件。

对预测数据执行相同的操作。

在做好训练数据之后,需要为这些数据创建一个读取类,以便torch中的DataLoader调用,而DataLoader中的内容则是Dataset,所以新建的读取类需要继承Dataset,并实现其__getitem__和__len__这两个成员方法。

这两个方法只是看上去吓人,但对Python稍有一点深入了解,就会知道__getitem__是字典索引的方法,而__len__则设定了len函数的返回值。

import h5py
import numpy as np
from torch.utils.data import Dataset
class DataSet(Dataset):
    def __init__(self, h5_file):
        super(Dataset, self).__init__()
        self.h5_file = h5_file
    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)
    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])

训练

首先,训练需要一点准备工作,比如数据集准备好,相关的文件夹需要建好,建好模型之后,需要采用什么样的优化方式。训练设备是用cpu还是cuda,然后将数据集和模型装载到设备上。

数据准备

import os
import copy
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from models import SRCNN
trainFile = "91-image.h5"
evalFile = "Set5.h5"
cudnn.benchmark = True
# 设置训练设备 是CPU还是cuda
device = torch.device(
  'cuda:0' if torch.cuda.is_available() else 'cpu')
# 装载训练数据
trainData = Dataset(trainFile)
trainLoader = DataLoader(dataset=trainData,
  bSize=bSize,
  shuffle=True,               # 表示打乱样本
  num_workers=nWorker,        # 线程数
  pin_memory=True,            # 方便载入CUDA
  drop_last=True)
# 装载预测数据
evalDatas = Dataset(evalFile)
evalLoader = DataLoader(dataset=evalDatas, bSize=1)

模型准备

# 模型和设备
lr = 1e-4       #学习率
torch.manual_seed(seed)     #设置随机数种子
model = SRCNN().to(device)  #将模型载入设备
criterion = nn.MSELoss()    #设置损失函数
optimizer = optim.Adam([
  {'params': model.conv1.parameters()},
  {'params': model.conv2.parameters()},
  {'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

train

outPath = "outputs"
scale = 3
bSize = 16
nEpoch = 400
nWorker = 8     #线程数
seed = 42       #随机数种子
def initPSNR():
    return {'avg':0, 'sum':0, 'count':0}
def updatePSNR(psnr, val, n=1):
    s = psnr['sum'] + val*n
    c = psnr['count'] + n
    return {'avg':s/c, 'sum':s, 'count':c}
bestWeights = copy.deepcopy(model.state_dict()) #最佳模型
bestEpoch = 0   #最佳训练结果
bestPSNR = 0.0  #最佳psnr
# 训练主循环
for epoch in range(nEpoch):
  model.train()
  epochLosses = initPSNR()
  for data in trainLoader:
      inputs, labels = data
      inputs = inputs.to(device)
      labels = labels.to(device)
      preds = model(inputs)
      loss = criterion(preds, labels)
      epochLosses = updatePSNR(epochLosses,loss.item(), len(inputs))
      optimizer.zero_grad()   #清空梯度
      loss.backward()         #反向传播
      optimizer.step()        #根据梯度更新网络参数
      print(f'{epochLosses['avg']:.6f}')
  torch.save(model.state_dict(), 
      os.path.join(outPath, f'epoch_{epoch}.pth'))
  model.eval()    #取消dropout
  psnr = AverageMeter()
  for data in evalLoader:
      inputs, labels = data
      inputs = inputs.to(device)
      labels = labels.to(device)
      # 令reqires_grad自动设为False,关闭自动求导
      # clamp将inputs归一化为0到1区间
      with torch.no_grad():
          preds = model(inputs).clamp(0.0, 1.0)
      tmp_psnr = 10. * torch.log10(
          1. / torch.mean((preds - labels) ** 2))
      psnr = updatePSNR(psnr, tmp_psnr, len(inputs))
  print(f'eval psnr: {psnr.avg:.2f}')
  if psnr['avg'] > bestPSNR:
      bestEpoch = epoch
      bestPSNR = psnr['avg']
      bestWeights = copy.deepcopy(model.state_dict())
print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')
torch.save(bestWeights, os.path.join(outPath, 'best.pth'))

最终结果是
原图来自网络