This article is a study and translation of the official TensorFlow getting-started tutorial, walking through the process of building a basic convolutional neural network for image classification. The steps are as follows.
1. Load the dataset
TensorFlow bundles the Keras framework, which provides the CIFAR10 dataset: 60,000 color images spread across 10 classes. It is loaded as follows
>>> import tensorflow as tf
>>> from tensorflow.keras import datasets, layers, models
>>> import matplotlib.pyplot as plt
>>> (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 33s 0us/step
>>> train_images, test_images = train_images / 255.0, test_images / 255.0
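The last line above rescales the data: CIFAR-10 pixels are 8-bit integers in [0, 255], and dividing by 255.0 maps them to floats in [0.0, 1.0], a range that gradient-based training generally handles better. A minimal sketch of that rescaling on a few hypothetical pixel values:

```python
# Hypothetical 8-bit pixel values, as stored in the raw CIFAR-10 arrays.
raw_pixels = [0, 64, 128, 255]

# The same division-by-255.0 normalization applied in the snippet above.
scaled = [p / 255.0 for p in raw_pixels]

# Every value now lies in [0.0, 1.0].
print(scaled)
```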
Some of the images can be previewed with the following code
>>> # class_names was missing from the original snippet; these are the
>>> # standard CIFAR-10 label names, indexed by the numeric label
>>> class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
...                'dog', 'frog', 'horse', 'ship', 'truck']
>>> for i in range(25):
...     plt.subplot(5, 5, i + 1)
...     plt.xticks([])
...     plt.yticks([])
...     plt.grid(False)
...     plt.imshow(train_images[i], cmap=plt.cm.binary)
...     plt.xlabel(class_names[train_labels[i][0]])
...
>>> plt.show()
The resulting grid of labeled images looks like this
2. Build the convolutional neural network
The network is built with the Keras Sequential API, adding convolutional, pooling, and fully connected layers in turn. The code is as follows
>>> model = models.Sequential()
>>> model.add(layers.Conv2D(32, (3, 3), activation="relu", input_shape=(32, 32, 3)))
>>> model.add(layers.MaxPooling2D((2, 2)))
>>> model.add(layers.Conv2D(64, (3, 3), activation="relu"))
>>> model.add(layers.MaxPooling2D((2, 2)))
>>> model.add(layers.Conv2D(64, (3, 3), activation="relu"))
>>> model.add(layers.Flatten())
>>> model.add(layers.Dense(64, activation="relu"))
>>> model.add(layers.Dense(10))
>>> model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 64)          36928
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0
_________________________________________________________________
dense_1 (Dense)              (None, 64)                65600
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
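The Output Shape and Param # columns in the summary follow directly from the layer arithmetic. The sketch below (pure Python, independent of TensorFlow) reproduces them: with "valid" padding, a layer with kernel size k and stride s maps a spatial size n to floor((n - k) / s) + 1, and a Conv2D layer with f filters on c input channels has k*k*c*f weights plus f biases.

```python
def conv_out(size, kernel, stride=1):
    # Output spatial size for "valid" padding (no zero padding).
    return (size - kernel) // stride + 1

# Conv2D(32, 3x3) on the 32x32x3 input
s = conv_out(32, 3)            # 30
p1 = 3 * 3 * 3 * 32 + 32       # 896 weights + biases

s = conv_out(s, 2, stride=2)   # MaxPooling2D(2x2): 15, no parameters
s = conv_out(s, 3)             # Conv2D(64, 3x3): 13
p2 = 3 * 3 * 32 * 64 + 64      # 18496
s = conv_out(s, 2, stride=2)   # MaxPooling2D(2x2): 6
s = conv_out(s, 3)             # Conv2D(64, 3x3): 4
p3 = 3 * 3 * 64 * 64 + 64      # 36928

flat = s * s * 64              # Flatten: 4 * 4 * 64 = 1024
p4 = flat * 64 + 64            # Dense(64): 65600
p5 = 64 * 10 + 10              # Dense(10): 650

total = p1 + p2 + p3 + p4 + p5
print(total)                   # 122570, matching "Total params" above
```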
3. Compile the model
Before training, the model must be compiled, which mainly means choosing the loss function, the optimizer, and the metric used to judge classification performance. The code is as follows
>>> model.compile(optimizer='adam',
...               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
...               metrics=['accuracy'])
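The `from_logits=True` argument matters here: the final `Dense(10)` layer has no softmax activation, so it outputs raw scores (logits), and the loss applies the softmax internally before taking the negative log-likelihood of the true class. A minimal pure-Python sketch of that computation, on hypothetical logits for a 3-class problem:

```python
import math

def sparse_ce_from_logits(logits, label):
    # Softmax over the raw scores, then negative log-probability of the
    # true class -- what SparseCategoricalCrossentropy(from_logits=True)
    # computes for a single example.
    exps = [math.exp(z) for z in logits]
    probs = [e / sum(exps) for e in exps]
    return -math.log(probs[label])

logits = [2.0, 1.0, 0.1]  # hypothetical raw scores from a final Dense layer
loss = sparse_ce_from_logits(logits, label=0)
print(round(loss, 4))     # about 0.417 -- low, since class 0 has the top score
```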
4. Train the model
Train the model on the training set with the following code
>>> history = model.fit(train_images, train_labels, epochs=10,
...                     validation_data=(test_images, test_labels))
2021-06-23 10:59:43.386592: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/10
1563/1563 [==============================] - 412s 203ms/step - loss: 1.5396 - accuracy: 0.4380 - val_loss: 1.2760 - val_accuracy: 0.5413
Epoch 2/10
1563/1563 [==============================] - 94s 60ms/step - loss: 1.1637 - accuracy: 0.5850 - val_loss: 1.1193 - val_accuracy: 0.6084
Epoch 3/10
1563/1563 [==============================] - 95s 61ms/step - loss: 1.0210 - accuracy: 0.6398 - val_loss: 0.9900 - val_accuracy: 0.6556
Epoch 4/10
1563/1563 [==============================] - 88s 56ms/step - loss: 0.9186 - accuracy: 0.6781 - val_loss: 0.9399 - val_accuracy: 0.6687
Epoch 5/10
1563/1563 [==============================] - 95s 61ms/step - loss: 0.8472 - accuracy: 0.7023 - val_loss: 0.8984 - val_accuracy: 0.6868
Epoch 6/10
1563/1563 [==============================] - 85s 55ms/step - loss: 0.7917 - accuracy: 0.7220 - val_loss: 0.8896 - val_accuracy: 0.6888
Epoch 7/10
1563/1563 [==============================] - 88s 56ms/step - loss: 0.7450 - accuracy: 0.7381 - val_loss: 0.8843 - val_accuracy: 0.6974
Epoch 8/10
1563/1563 [==============================] - 87s 55ms/step - loss: 0.7024 - accuracy: 0.7530 - val_loss: 0.8403 - val_accuracy: 0.7089
Epoch 9/10
1563/1563 [==============================] - 92s 59ms/step - loss: 0.6600 - accuracy: 0.7676 - val_loss: 0.8512 - val_accuracy: 0.7095
Epoch 10/10
1563/1563 [==============================] - 91s 58ms/step - loss: 0.6240 - accuracy: 0.7790 - val_loss: 0.8483 - val_accuracy: 0.7119
Comparing the accuracy curves on the training and validation sets shows whether the model suffers from problems such as overfitting. The code is as follows
>>> plt.plot(history.history['accuracy'], label='accuracy')
[<matplotlib.lines.Line2D object at 0x000001AAC62A7B08>]
>>> plt.plot(history.history['val_accuracy'], label='val_accuracy')
[<matplotlib.lines.Line2D object at 0x000001AAC28F8988>]
>>> plt.xlabel('Epoch')
Text(0.5, 0, 'Epoch')
>>> plt.ylabel('Accuracy')
Text(0, 0.5, 'Accuracy')
>>> plt.ylim([0.5, 1])
(0.5, 1.0)
>>> plt.legend(loc='lower right')
<matplotlib.legend.Legend object at 0x000001AAC62A7688>
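The same overfitting check can be done numerically. The sketch below copies the accuracy numbers from the training log above into a plain dict (`history.history` has exactly this structure) and measures the train/validation gap; a gap that keeps widening while validation accuracy stalls is the usual sign of overfitting.

```python
# Accuracy values transcribed from the 10-epoch training output above.
history_dict = {
    'accuracy':     [0.4380, 0.5850, 0.6398, 0.6781, 0.7023,
                     0.7220, 0.7381, 0.7530, 0.7676, 0.7790],
    'val_accuracy': [0.5413, 0.6084, 0.6556, 0.6687, 0.6868,
                     0.6888, 0.6974, 0.7089, 0.7095, 0.7119],
}

# Per-epoch gap between training and validation accuracy.
gaps = [a - v for a, v in zip(history_dict['accuracy'],
                              history_dict['val_accuracy'])]
final_gap = gaps[-1]
print(f"final train/val gap: {final_gap:.4f}")  # about 0.067 after epoch 10
```

Here the gap is still modest, consistent with the curves in the plot, but it has grown steadily since the training accuracy overtook the validation accuracy, so training many more epochs would likely start to overfit.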