在前一篇文章 【深度域自适应】一、DANN与梯度反转层(GRL)详解 中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文 Unsupervised Domain Adaptation by Backpropagation 中MNIST和MNIST-M数据集的迁移训练实验。
为了利用DANN实现MNIST和MNIST-M数据集的迁移训练,我们首先需要获取到MNIST和MNIST-M数据集。其中MNIST数据集很容易获取,官网下载链接为: MNSIT 。需要下载的文件如下图所示蓝色的4个文件。
由于tensorflow和keras深度融合,我们可以通过keras的相关API进行MNIST数据集,如下:
from tensorflow.keras.datasets import mnist # 导入MNIST数据集 (X_train,y_train),(X_test,y_test) = mnist.load_data()
MNIST-M数据集由MNIST数字与BSDS500数据集中的随机色块混合而成。那么要像生成MNIST-M数据集,请首先下载BSDS500数据集。BSDS500数据集的官方下载地址为: BSDS500 。 以下是BSDS500数据集官方网址相关截图,点击下图中蓝框的连接即可下载数据。
下载好BSDS500数据集后,我们必须根据MNIST和BSDS500数据集来生成MNIST-M数据集,生成数据集的脚本 create_mnistm.py 如下:
create_mnistm.py
# -*- coding: utf-8 -*- # @Time : 2021/7/24 下午1:50 # @Author : Dai Pu wei # @Email : 771830171@qq.com # @File : create_mnistm.py # @Software: PyCharm from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tarfile import numpy as np import pickle as pkl import skimage.io import skimage.transform from tensorflow.keras.datasets import mnist rand = np.random.RandomState(42) def compose_image(mnist_data, background_data): 这是将MNIST数据和BSDS500数据进行融合成MNIST-M数据的函数 :param mnist_data: MNIST数据 :param background_data: BDSD500数据,作为背景图像 :return: # 随机融合MNIST数据和BSDS500数据 w, h, _ = background_data.shape dw, dh, _ = mnist_data.shape x = np.random.randint(0, w - dw) y = np.random.randint(0, h - dh) bg = background_data[x:x + dw, y:y + dh] return np.abs(bg - mnist_data).astype(np.uint8) def mnist_to_img(x): 这是实现MNIST数据格式转换的函数,0/1数据位转化为RGB数据集 :param x: 0/1格式MNIST数据 :return: x = (x > 0).astype(np.float32) d = x.reshape([28, 28, 1]) * 255 return np.concatenate([d, d, d], 2) def create_mnistm(X,background_data): 这是生成MNIST-M数据集的函数,MNIST-M数据集介绍可见: http://jmlr.org/papers/volume17/15-239/15-239.pdf :param X: MNIST数据集 :param background_data: BSDS500数据集,作为背景 :return: # 遍历所有MNIST数据集,生成MNIST-M数据集 X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8) for i in range(X.shape[0]): if i % 1000 == 0: print('Processing example', i) # 随机选择背景图像 bg_img = rand.choice(background_data) # 0/1数据位格式MNIST数据转换为RGB格式 mnist_image = mnist_to_img(X[i]) # 将MNIST数据和BSDS500数据背景进行融合 mnist_image = compose_image(mnist_image, bg_img) X_[i] = mnist_image return X_ def run_main(): 这是主函数 # 初始化路径 BST_PATH = os.path.abspath('./model_data/dataset/BSR_bsds500.tgz') mnist_dir = os.path.abspath("model_data/dataset/MNIST") mnistm_dir = os.path.abspath("model_data/dataset/MNIST_M") # 导入MNIST数据集 (X_train,y_train),(X_test,y_test) = mnist.load_data() # 加载BSDS500数据集 f = tarfile.open(BST_PATH) train_files = [] for name in f.getnames(): if name.startswith('BSR/BSDS500/data/images/train/'): train_files.append(name) print('Loading BSR training images') background_data = [] for name in train_files: fp = f.extractfile(name) bg_img = skimage.io.imread(fp) background_data.append(bg_img) except: continue # 生成MNIST-M训练数据集和验证数据集 print('Building train set...') train = create_mnistm(X_train,background_data) print(np.shape(train)) print('Building validation set...') valid = create_mnistm(X_test,background_data) print(np.shape(valid)) # 将MNIST数据集转化为RGB格式 X_train = np.expand_dims(X_train,-1) X_test = np.expand_dims(X_test,-1) X_train = np.concatenate([X_train,X_train,X_train],axis=3) X_test = np.concatenate([X_test,X_test,X_test],axis=3) y_train = np.array(y_train).astype(np.int32) y_test = np.array(y_test).astype(np.int32) # 保存MNIST数据集为pkl文件 if not os.path.exists(mnist_dir): os.mkdir(mnist_dir) with open(os.path.join(mnist_dir, 'mnist_data.pkl'), 'wb') as f: pkl.dump({'train': X_train, 'train_label': y_train, 'val': X_test, 'val_label':y_test}, f, pkl.HIGHEST_PROTOCOL) # 保存MNIST-M数据集为pkl文件 if not os.path.exists(mnistm_dir): os.mkdir(mnistm_dir) with open(os.path.join(mnistm_dir, 'mnist_m_data.pkl'), 'wb') as f: pkl.dump({'train': train, 'train_label':y_train, 'val': valid, 'val_label':y_test}, f, pkl.HIGHEST_PROTOCOL) # 计算数据集平均值,用于数据标准化 print(np.shape(X_train)) print(np.shape(X_test)) print(np.shape(train)) print(np.shape(valid)) print(np.shape(y_train)) print(np.shape(y_test)) pixel_mean = np.vstack([X_train,train,X_test,valid]).mean((0,1,2)) print(np.shape(pixel_mean)) print(pixel_mean) if __name__ == '__main__': run_main()
由于整个DANN-MNIST网络的训练过程中涉及到很多超参数,因此为了整个项目的编程方便,我们利用面向对象的思想将所有的超参数放置到一个类中,即参数配置类config。这个参数配置类config的代码如下:
# -*- coding: utf-8 -*- # @Time : 2020/2/15 15:05 # @Author : Dai PuWei # @Email : 771830171@qq.com # @File : config.py # @Software: PyCharm import os class config(object): __defualt_dict__ = { "pre_model_path":None, "checkpoints_dir":os.path.abspath("./checkpoints"), "logs_dir":os.path.abspath("./logs"), "config_dir":os.path.abspath("./config"), "image_input_shape":(28,28,3), "image_size":28, "init_learning_rate": 1e-2, "momentum_rate":0.9, "batch_size":256, "epoch":500, "pixel_mean":[45.652287,45.652287,45.652287], def __init__(self,**kwargs): 这是参数配置类的初始化函数 :param kwargs: 参数字典 # 初始化相关配置参数 self.__dict__.update(self. __defualt_dict__) # 根据相关传入参数进行参数更新 self.__dict__.update(kwargs) if not os.path.exists(self.checkpoints_dir): os.makedirs(self.checkpoints_dir) if not os.path.exists(self.logs_dir): os.makedirs(self.logs_dir) if not os.path.exists(self.config_dir): os.makedirs(self.config_dir) def set(self,**kwargs): 这是参数配置的设置函数 :param kwargs: 参数字典 :return: # 根据相关传入参数进行参数更新 self.__dict__.update(kwargs) def save_config(self,time): 这是保存参数配置类的函数 :param time: 时间点字符串 :return: # 更新相关目录 self.checkpoints_dir = os.path.join(self.checkpoints_dir,time) self.logs_dir = os.path.join(self.logs_dir,time) self.config_dir = os.path.join(self.config_dir,time) if not os.path.exists(self.config_dir): os.makedirs(self.config_dir) if not os.path.exists(self.checkpoints_dir): os.makedirs(self.checkpoints_dir) if not os.path.exists(self.logs_dir): os.makedirs(self.logs_dir) config_txt_path = os.path.join(self.config_dir,"config.txt") with open(config_txt_path,'a') as f: for key,value in self.__dict__.items(): if key in ["checkpoints_dir","logs_dir","config_dir"]: value = os.path.join(value,time) s = key+": "+value+"\n" f.write(s)
在DANN中比较重要的模块就是梯度反转层(Gradient Reversal Layer, GRL)的实现。GRL的tf2.x代码实现如下:
import tensorflow as tf from tensorflow.keras.layers import Layer @tf.custom_gradient def gradient_reversal(x,alpha=1.0): def grad(dy): return -dy * alpha, None return x, grad class GradientReversalLayer(Layer): def __init__(self,**kwargs): 这是梯度反转层的初始化函数 :param kwargs: 参数字典 super(GradientReversalLayer,self).__init__(kwargs) def call(self, x,alpha=1.0): 这是梯度反转层的初始化函数 :param x: 输入张量 :param alpha: alpha系数,默认为1 :return: return gradient_reversal(x,alpha)
在上述代码中@ops.RegisterGradient(grad_name)修饰 _flip_gradients(op, grad)函数,即自定义该层的梯度取反。同时gradient_override_map函数主要用于解决使用自己定义的函数方式来求梯度的问题,gradient_override_map函数的参数值为一个字典。即字典中value表示使用该值表示的函数代替key表示的函数进行梯度运算。
DANN论文 Unsupervised Domain Adaptation by Backpropagation 中给出MNIST和MNIST-M数据集的迁移训练实验的网络,网络架构图如下图所示。 接下来,我们将利用tensorflow2.4.0来搭建整个DANN-MNIST网络,DANN-MNIST网络结构代码如下:
# -*- coding: utf-8 -*- # @Time : 2020/2/14 20:27 # @Author : Dai PuWei # @Email : 771830171@qq.com # @File : MNIST2MNIST_M.py # @Software: PyCharm import tensorflow as tf from tensorflow.keras.layers import Dense from tensorflow.keras.layers import Conv2D from tensorflow.keras.layers import Flatten from tensorflow.keras.layers import MaxPool2D from tensorflow.keras.layers import Activation def build_feature_extractor(): 这是特征提取子网络的构建函数 :param image_input: 图像输入张量 :param name: 输出特征名称 :return: model = tf.keras.Sequential([Conv2D(filters=32, kernel_size=5,strides=1), #tf.keras.layers.BatchNormalization(), Activation('relu'), MaxPool2D(pool_size=(2, 2), strides=2), Conv2D(filters=48, kernel_size=5,strides=1), #tf.keras.layers.BatchNormalization(), Activation('relu'), MaxPool2D(pool_size=(2, 2), strides=2), Flatten(), return model def build_image_classify_extractor(): 这是搭建图像分类器模型的函数 :param image_classify_feature: 图像分类特征张量 :return: model = tf.keras.Sequential([Dense(100), #tf.keras.layers.BatchNormalization(), Activation('relu'), #tf.keras.layers.Dropout(0.5), Dense(100,activation='relu'), #tf.keras.layers.Dropout(0.5), Dense(10,activation='softmax',name="image_cls_pred"), return model def build_domain_classify_extractor(): 这是搭建域分类器的函数 :param domain_classify_feature: 域分类特征张量 :return: # 搭建域分类器 model = tf.keras.Sequential([Dense(100),