No Mosaic in the Mind, Naturally HD || A PyTorch Reproduction of a Joint Demosaicing and Super-Resolution Paper
GitHub repository (from a newbie): DRNFJDSR
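Structure of this article
- A quick primer
  - What is demosaicing
  - What is super-resolution
- Introduction to the paper "Deep Residual Network for Joint Demosaicing and Super-Resolution"
  - Contributions of the paper
  - Model architecture
  - Training data
  - Results of the paper's model
- Reproducing the paper
  - PyTorch code
    - Model
    - DataSet
    - Train
  - Details to watch out for
  - Reproduction results
    - Numerical results
    - Sample images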
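I. A Quick Primer
1. What is demosaicing?
First, demosaicing. Everybody knows what that is, right?
Of course it is not what the picture above shows. Dear readers, please don't get the wrong idea: this "mosaic" is not that kind of mosaic. Demosaicing is a key step in how a digital camera forms an image, and to explain it we have to start with the camera's image sensor.
A digital image is made up of pixels, and each pixel's color is a mix of three components: red, green and blue (RGB). A camera's sensor, however, can only measure light intensity. Capturing the red, green and blue intensities all at the same point would be a nightmare in terms of both sensor structure and manufacturing cost. So how do we get around this?
This is where Bryce Bayer steps up, waving the Bayer array he invented and shouting: "Fear not, little brothers, big brother is here to save you!"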
https://zh.wikipedia.org/wiki/%E6%8B%9C%E7%88%BE%E6%BF%BE%E8%89%B2%E9%8F%A1
Introduction to the Bayer array (Bayer mosaic) - 风之盔's blog - CSDN Blog
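The idea behind the Bayer array is simple: if sampling three colors at one point is hard, just sample one. Why make life hard for the sensor? Since we still need all three colors, the trick lies in how the color filters are arranged:
Each point captures only one of the three colors; the other two can be borrowed from its neighbors. This process of "borrowing" is what we call demosaicing:
Having seen the figure above, can you see why it is called "demosaicing"?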
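To make the "borrowing from neighbors" concrete, here is a minimal NumPy sketch of the classic bilinear demosaicing baseline. It is not the paper's method, only the simplest possible form of the idea. It assumes the RGGB layout (top-left 2x2 block is R G / G B), and the function names are illustrative. Every missing color at a pixel is the average of the nearest pixels that did record that color.

```python
import numpy as np

def bayer_masks(h, w):
    """Boolean masks of the R, G and B sites of an RGGB Bayer array."""
    r = np.zeros((h, w), dtype=bool)
    b = np.zeros((h, w), dtype=bool)
    r[0::2, 0::2] = True          # R on even rows, even columns
    b[1::2, 1::2] = True          # B on odd rows, odd columns
    g = ~(r | b)                  # the remaining half of the sites are G
    return r, g, b

def mosaic(rgb):
    """Simulate the sensor: keep one color per pixel -> single-channel RAW."""
    r, g, b = bayer_masks(*rgb.shape[:2])
    return rgb[..., 0] * r + rgb[..., 1] * g + rgb[..., 2] * b

def filter3x3(img, k):
    """3x3 correlation with reflect padding (reflection keeps the Bayer parity at borders)."""
    h, w = img.shape
    p = np.pad(img, 1, mode="reflect")
    out = np.zeros((h, w))
    for dy in range(3):
        for dx in range(3):
            out += k[dy, dx] * p[dy:dy + h, dx:dx + w]
    return out

# Bilinear kernels: missing G = mean of the 4 direct neighbors;
# missing R/B = mean of the 2 straight or 4 diagonal neighbors of that color.
K_G = np.array([[0, 1, 0], [1, 4, 1], [0, 1, 0]]) / 4.0
K_RB = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 4.0

def demosaic_bilinear(raw):
    """Fill in the two missing colors at every pixel from its neighbors."""
    r, g, b = bayer_masks(*raw.shape)
    out = np.empty(raw.shape + (3,))
    out[..., 0] = filter3x3(raw * r, K_RB)
    out[..., 1] = filter3x3(raw * g, K_G)
    out[..., 2] = filter3x3(raw * b, K_RB)
    return out

if __name__ == "__main__":
    rgb = np.full((4, 6, 3), [0.2, 0.5, 0.9])  # a flat color patch
    raw = mosaic(rgb)                          # (4, 6): one value per pixel
    rec = demosaic_bilinear(raw)               # (4, 6, 3): RGB again
    print(np.allclose(rec, rgb))               # a flat patch is recovered exactly
```

On flat regions this works perfectly; on edges and fine textures simple averaging produces color fringes and blur, which is what learned methods try to fix.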
Related algorithms include FlexISP, ADMM and DemosaicNet.
Deep Demosaicking
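2. What is super-resolution?
In short, it means turning a low-resolution image into a high-resolution one:
There are already many deep-learning super-resolution methods, such as SRCNN, FSRCNN, ESPCN and VDSR.
桂花糖: From SRCNN to EDSR, a summary of how end-to-end deep-learning super-resolution methods have evolved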
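II. Introduction to "Deep Residual Network for Joint Demosaicing and Super-Resolution"
Download: Deep Residual Network for Joint Demosaicing and Super-Resolution
1. Contributions of the paper
As its title suggests, the paper's main contribution is being the first to perform demosaicing and super-resolution jointly: it mines as much information as possible directly from the single-channel RAW image and produces a super-resolved three-channel image in one step. Compared with running demosaicing first and super-resolution second, this has two benefits: errors from the first stage are no longer passed on and amplified by the second, which yields higher-quality images, and the amount of computation, and therefore the running time, is reduced.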
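2. Model architecture
The model consists of three stages:
a. Color extraction: a 4x4 convolution is used to extract the actual color information at each position of the Bayer image.
b. Non-linear mapping: residual-network blocks are stacked into a deep network that extracts features.
c. Image reconstruction: borrowing the sub-pixel structure from ESPCN, the number of channels is reduced by a factor of 4 so that the image height and width are each doubled, which achieves the super-resolution.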
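The sub-pixel step in (c) is a fixed rearrangement of channels into space: a tensor with C·r² channels of size H×W becomes C channels of size rH×rW. The NumPy sketch below follows the element layout used by `torch.nn.PixelShuffle`, output[c, h·r+i, w·r+j] = input[c·r²+i·r+j, h, w]; `pixel_shuffle` is an illustrative name, not a library function.

```python
import numpy as np

def pixel_shuffle(x, r):
    """(N, C*r*r, H, W) -> (N, C, H*r, W*r), same element layout as torch.nn.PixelShuffle."""
    n, c_in, h, w = x.shape
    assert c_in % (r * r) == 0, "channel count must be divisible by r^2"
    c = c_in // (r * r)
    # split the channel axis into (C, r, r), then interleave the two r axes with H and W
    x = x.reshape(n, c, r, r, h, w)
    x = x.transpose(0, 1, 4, 2, 5, 3)   # (N, C, H, r, W, r)
    return x.reshape(n, c, h * r, w * r)

if __name__ == "__main__":
    feat = np.zeros((1, 256, 32, 32))   # e.g. C = 256 feature maps
    print(pixel_shuffle(feat, 2).shape) # (1, 64, 64, 64): 4x fewer channels, 2x larger H and W
```

No values are computed, only moved, which is why the layer has no parameters; the convolution in front of it is what learns how to fill the r×r sub-pixels.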
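In the paper:
a. The number of feature maps is C = 256.
b. The residual block structure is shown in the figure below; the paper uses 24 blocks:
c. For the sub-pixel layer, see ESPCN:
d. The batch size is 16x3x64x64.
e. The learning rate is halved every 10,000 batches.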
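Item (e) is a step-decay schedule, lr(n) = lr0 · 0.5^⌊n / 10000⌋ with n the batch count. A pure-Python sketch follows; the initial value 1e-4 is an assumption taken from this article's training script, not a number quoted from the paper here, and `step_lr` is a hypothetical helper.

```python
def step_lr(batch_idx, base_lr=1e-4, half_every=10000):
    """LR at 0-based batch number batch_idx when it is halved every half_every batches.

    base_lr=1e-4 is an assumed starting value (the one used in the training script).
    """
    return base_lr * 0.5 ** (batch_idx // half_every)

if __name__ == "__main__":
    for n in (0, 9999, 10000, 25000):
        print(n, step_lr(n))
```

In PyTorch the same schedule can be obtained with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.5)` and `scheduler.step()` called once per batch, which also keeps Adam's internal state. The training script in this article uses half_every = 40000 with a batch size of 8.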
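3. Training dataset
The paper uses 6,000 high-resolution images from the RAISE dataset:
Download: The Raw Images Dataset
The images are processed as shown in the figure:
1. The original 16MP TIFF images are downscaled three times by a factor of 1.25, giving 4MP TIFF images.
2. The 4MP TIFF images are downscaled once by a factor of 2, giving 1MP TIFF images.
3. In each 1MP image, for every pixel, the two channels that do not match the Bayer array at that position are erased (G and B at R sites, R and B at G sites, R and G at B sites), leaving a Bayer image (similar to the figure below); the three channels are then merged into a single channel.
4. The training set is now ready: the data is the 1MP Bayer image, and the label is the 4MP image produced in step 1.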
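Steps 2 and 3 can be sketched in NumPy as below, turning one label image into its half-size Bayer input. Assumptions: the article does not say which resize kernel is used, so a 2x2 box average stands in for the factor-2 resize; the layout is RGGB, matching the comment in the dataset code; `make_training_pair` is a hypothetical helper. Also note that three 1.25x downscales shrink each side by 1.25³ ≈ 1.95, so 16MP becomes about 16 / 1.95² ≈ 4.2MP, i.e. "4MP" is approximate.

```python
import numpy as np

def box_downscale2(img):
    """Halve H and W by averaging each 2x2 block (stand-in for a factor-2 resize)."""
    h, w = img.shape[0] // 2 * 2, img.shape[1] // 2 * 2
    img = img[:h, :w]
    return img.reshape(h // 2, 2, w // 2, 2, -1).mean(axis=(1, 3))

def make_training_pair(label_rgb):
    """(2H, 2W, 3) label -> (data, label), data being an (H, W) RGGB Bayer image."""
    small = box_downscale2(label_rgb)        # step 2: (H, W, 3)
    h, w = small.shape[:2]
    # step 3: index of the one channel kept at each Bayer site
    channel = np.ones((h, w), dtype=int)     # G sites by default
    channel[0::2, 0::2] = 0                  # R sites
    channel[1::2, 1::2] = 2                  # B sites
    rows, cols = np.indices((h, w))
    bayer = small[rows, cols, channel]       # three channels merged into one
    return bayer, label_rgb[:2 * h, :2 * w]

if __name__ == "__main__":
    label = np.random.default_rng(0).random((8, 12, 3))
    data_img, lab = make_training_pair(label)
    print(data_img.shape, lab.shape)         # (4, 6) (8, 12, 3)
```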
4. Results of the paper's model
III. Reproducing the Paper
1. PyTorch code:
1.1 Model:
```python
import torch
import torch.nn as nn

# ResNet-style residual block
# https://blog.csdn.net/sunqiande88/article/details/80100891
class ResidualBlock(nn.Module):
    def __init__(self):
        super(ResidualBlock, self).__init__()
        # conv -> PReLU -> conv; 3x3 kernels with padding 1 keep H and W unchanged
        self.left = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.PReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True),
        )
        # identity shortcut: an empty Sequential returns its input unchanged
        self.shortcut = nn.Sequential()
        self.active_f = nn.PReLU()

    def forward(self, x):
        out = self.left(x)
        out = out + self.shortcut(x)
        out = self.active_f(out)
        return out


class Net(nn.Module):
    def __init__(self, resnet_level=2):
        super(Net, self).__init__()
        # ***Stage 1: color extraction***
        # class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
        # 4x4 conv, stride 2: (1, H, W) Bayer input -> (256, H/2, W/2)
        self.stage1_1_conv4x4 = nn.Conv2d(in_channels=1, out_channels=256,
                                          kernel_size=4, stride=2, padding=1, bias=True)
        # Reference: class torch.nn.PixelShuffle(upscale_factor)
        # >>> pixel_shuffle = nn.PixelShuffle(3)
        # >>> input = torch.randn(1, 9, 4, 4)
        # >>> output = pixel_shuffle(input)
        # >>> print(output.size())
        # torch.Size([1, 1, 12, 12])
        # sub-pixel layer: (256, H/2, W/2) -> (64, H, W)
        self.stage1_2_SP_conv = nn.PixelShuffle(2)
        # despite its name, this is a 3x3 conv: 64 -> 256 channels
        self.stage1_2_conv4x4 = nn.Conv2d(in_channels=64, out_channels=256,
                                          kernel_size=3, stride=1, padding=1, bias=True)
        # class torch.nn.PReLU(num_parameters=1, init=0.25)
        self.stage1_2_PReLU = nn.PReLU()

        # ***Stage 2: non-linear mapping with resnet_level residual blocks***
        self.stage2_ResNetBlock = nn.Sequential(
            *[ResidualBlock() for _ in range(resnet_level)])

        # ***Stage 3: reconstruction***
        # sub-pixel layer: (256, H, W) -> (64, 2H, 2W)
        self.stage3_1_SP_conv = nn.PixelShuffle(2)
        self.stage3_2_conv3x3 = nn.Conv2d(in_channels=64, out_channels=256,
                                          kernel_size=3, stride=1, padding=1, bias=True)
        self.stage3_2_PReLU = nn.PReLU()
        # final conv down to the 3 RGB output channels
        self.stage3_3_conv3x3 = nn.Conv2d(in_channels=256, out_channels=3,
                                          kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        out = self.stage1_1_conv4x4(x)
        out = self.stage1_2_SP_conv(out)
        out = self.stage1_2_conv4x4(out)
        out = self.stage1_2_PReLU(out)
        out = self.stage2_ResNetBlock(out)
        out = self.stage3_1_SP_conv(out)
        out = self.stage3_2_conv3x3(out)
        out = self.stage3_2_PReLU(out)
        out = self.stage3_3_conv3x3(out)
        return out
```
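A quick way to sanity-check the stages without a GPU is to trace the tensor shapes by hand. The pure-Python sketch below applies the standard Conv2d output-size formula ⌊(H + 2p − k) / s⌋ + 1 and the PixelShuffle rule (C·r², H, W) → (C, rH, rW) to the layer settings in the model code above; `conv_out` and `trace_shapes` are illustrative helpers, not part of the repository.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of a Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

def trace_shapes(h, w, feats=256):
    """Follow (C, H, W) through the layers of Net for a (1, h, w) Bayer input."""
    trace = [("input", (1, h, w))]
    # Stage 1: 4x4 conv, stride 2, padding 1 -> 256 channels at half resolution
    c, h, w = feats, conv_out(h, 4, 2, 1), conv_out(w, 4, 2, 1)
    trace.append(("stage1 conv4x4/s2", (c, h, w)))
    # PixelShuffle(2): 4x fewer channels, 2x larger H and W
    c, h, w = c // 4, h * 2, w * 2
    trace.append(("stage1 PixelShuffle(2)", (c, h, w)))
    c = feats  # 3x3 conv with padding 1 keeps H and W
    trace.append(("stage1 conv3x3", (c, h, w)))
    # Stage 2: residual blocks leave (C, H, W) unchanged
    trace.append(("stage2 residual blocks", (c, h, w)))
    # Stage 3: PixelShuffle(2) -> conv3x3 (256) -> conv3x3 (3)
    c, h, w = c // 4, h * 2, w * 2
    trace.append(("stage3 PixelShuffle(2)", (c, h, w)))
    trace.append(("stage3 conv3x3", (feats, h, w)))
    trace.append(("output", (3, h, w)))
    return trace

for name, shape in trace_shapes(64, 64):
    print(f"{name:24s} {shape}")
```

For a 64x64 Bayer patch the output is 3x128x128, consistent with the label being twice the size of the data. With an odd side length, the stride-2 convolution drops a pixel (65 → 32 → 64), so the patches should have even height and width.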
1.2 DataSet:
```python
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

# References on building a custom dataset:
# https://oidiotlin.com/create-custom-dataset-in-pytorch/
# https://www.pytorchtutorial.com/pytorch-custom-dataset-examples/
# Using transforms:
# https://www.jianshu.com/p/13e31d619c15
# ToTensor: converts a PIL image (H x W x C) with values in [0, 255]
# to a torch.FloatTensor (C x H x W) with values in [0.0, 1.0]
# torch.set_default_tensor_type('torch.DoubleTensor')

class CustomDataset(data.Dataset):
    # file_path: path of a TXT file; each line is "<data_path> <label_path>"
    # block_size=64: crop size (stored only; __getitem__ does no cropping,
    # so random-crop augmentation is not implemented here)
    def __init__(self, file_path, block_size=64):
        with open(file_path, 'r') as file:
            # skip blank lines such as a trailing newline
            self.imgs = [line.strip().split(' ') for line in file if line.strip()]
        self.Block_size = block_size
        self.trans = transforms.ToTensor()
        print("DataSet Size is: ", self.__len__())

    def __getitem__(self, index):
        # Note! The top-left 2x2 block of every Bayer image read in must be:
        # R G
        # G B
        # Reference API:
        # class torchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)
        # class torchvision.transforms.Compose([transforms_list,]) -> returns a callable
        data_path, label_path = self.imgs[index]
        bayer = Image.open(data_path).convert('L')     # single-channel Bayer input
        label = Image.open(label_path).convert('RGB')  # three-channel target
        return self.trans(bayer), self.trans(label)    # (1, H, W), (3, 2H, 2W)

    def __len__(self):
        return len(self.imgs)
```
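The class above stores `block_size` but never crops. If random-crop augmentation were added, it would have to respect the note in `__getitem__`: the Bayer patch must start at an even row and column so its top-left 2x2 block is still R G / G B, and the label must be cropped at the matching position with twice the size. A minimal NumPy sketch of such a paired crop, where `paired_bayer_crop` is a hypothetical helper and not part of the repository:

```python
import numpy as np

def paired_bayer_crop(bayer, label, block_size, rng=None):
    """Random crop of an (H, W) Bayer image and the matching (2H, 2W, 3) label.

    Offsets are forced to be even so the crop still starts on an R pixel (RGGB).
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = bayer.shape
    if block_size % 2 or block_size > min(h, w):
        raise ValueError("block_size must be even and fit inside the image")
    # draw among even offsets only: 0, 2, 4, ...
    y = 2 * int(rng.integers(0, (h - block_size) // 2 + 1))
    x = 2 * int(rng.integers(0, (w - block_size) // 2 + 1))
    data_patch = bayer[y:y + block_size, x:x + block_size]
    label_patch = label[2 * y:2 * (y + block_size), 2 * x:2 * (x + block_size)]
    return data_patch, label_patch, (y, x)

if __name__ == "__main__":
    rng = np.random.default_rng(1)
    d, l, (y, x) = paired_bayer_crop(rng.random((20, 30)), rng.random((40, 60, 3)), 8, rng)
    print(d.shape, l.shape, (y, x))
```

In the dataset, the crop would be applied to `np.asarray(...)` versions of the two PIL images before the tensor conversion.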
1.3 Train:
```python
import torch
import torch.utils.data as data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
from PIL import Image
from DataSet import CustomDataset
from NewResNet import Net
from multiprocessing import Process
from Test_class import Run_test

# *** Hyperparameters ***
Parameter_path = './Final_train_LR.txt'  # current LR is saved here so training can resume
MODEL_PATH = './Final_Model.pkl'
EPOCH = 1
HALF_LR_STEP = 40000  # halve the LR every HALF_LR_STEP batches
LR = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Paths of the training and test sets
train_data_path = "./8K_TRAIN_DATA/8K_TRAIN_DATA.txt"
test_data_path = "./8K_CROSS_DATA/8K_CROSS_DATA.txt"
BATCH_BLOCK_SIZE = 64
BATCH_SIZE = 8
DATA_SHUFFLE = True
# Check whether a GPU is available
print("cuda:", torch.cuda.is_available(), "GPUs", torch.cuda.device_count())

# Saving and restoring a model
# https://www.cnblogs.com/nkh222/p/7656623.html
# https://blog.csdn.net/quincuntial/article/details/78045036
# torch.save(the_model.state_dict(), PATH)
# the_model = TheModelClass(*args, **kwargs)
# the_model.load_state_dict(torch.load(PATH))
# # Save only the network parameters (the officially recommended way)
# torch.save(net.state_dict(), 'net_params.pkl')
# # Load the network parameters
# net.load_state_dict(torch.load('net_params.pkl'))

print("Loading the LR...")
try:
    with open(Parameter_path) as P:
        LR = float(P.readline())
except (OSError, ValueError):
    print("Loading LR fail...")

print("Loading the saving Model...")
MyNet = Net(24).to(device)
try:
    MyNet.load_state_dict(torch.load(MODEL_PATH, map_location=device))
except Exception:
    print("Loading Fail.")

print("Loading the Training data...")
MyData = CustomDataset(file_path=train_data_path,
                       block_size=BATCH_BLOCK_SIZE)
# class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
#     sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>,
#     pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
train_data = data.DataLoader(dataset=MyData,
                             batch_size=BATCH_SIZE,
                             shuffle=DATA_SHUFFLE)
# class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
Optimizer = torch.optim.Adam(MyNet.parameters(), lr=LR, betas=(0.9, 0.999), eps=1e-08)
# class torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')
Loss_Func = nn.MSELoss()
counter = 0
print("Start training...")
for epoch in range(EPOCH):
    for step, (data, label) in enumerate(train_data):
        counter = counter + 1
        if counter % HALF_LR_STEP == 0:
            LR = LR / 2
            # update the LR in place; re-creating Adam would also reset its moment estimates
            for param_group in Optimizer.param_groups:
                param_group['lr'] = LR
            with open(Parameter_path, 'w') as f:
                f.write(str(LR))
            print('LR:', LR)
        data, label = data.to(device), label.to(device)
        start = time.perf_counter()
        out = MyNet(data)
        loss = Loss_Func(out, label)
        Optimizer.zero_grad()
        loss.backward()
        Optimizer.step()
        print(loss.item())
```