Full code download (GitHub link):
github.com/lmn-ning/MN…
1. MNIST Dataset Introduction and Download Link
MNIST handwritten digit recognition is arguably the "hello world" of getting started with machine learning. The MNIST dataset contains 70,000 images of handwritten digits: 60,000 for training and 10,000 for testing.
Official download page:
yann.lecun.com/exdb/mnist/
The MNIST dataset consists of four files:
train-images-idx3-ubyte.gz: training set images, 60,000 in total.
train-labels-idx1-ubyte.gz: labels for the training set images.
t10k-images-idx3-ubyte.gz: test set images, 10,000 in total.
t10k-labels-idx1-ubyte.gz: labels for the test set images.
The images are handwritten digits 0-9, i.e. 10 classes, and each label is the actual digit shown in the image. Every image is a 28x28 single-channel grayscale image with the digit centered, which keeps preprocessing minimal and training fast.
You can either download the dataset yourself and build the data iterators from the raw files, or use torchvision's built-in download to fetch it online (this project uses the torchvision route; a sketch of the manual route follows below).
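If you go the manual route, the four files use the simple IDX binary format: a fixed-size big-endian header followed by raw bytes. Below is a minimal sketch (not part of this repo's code) of reading a downloaded image/label pair with NumPy; the paths are placeholders for wherever you saved the .gz files.
import gzip
import numpy as np

def load_idx_images(path):
    # image file: 16-byte header (magic, count, rows, cols), then uint8 pixels
    with gzip.open(path, "rb") as f:
        data = f.read()
    count = int.from_bytes(data[4:8], "big")
    rows = int.from_bytes(data[8:12], "big")
    cols = int.from_bytes(data[12:16], "big")
    return np.frombuffer(data, dtype=np.uint8, offset=16).reshape(count, rows, cols)

def load_idx_labels(path):
    # label file: 8-byte header (magic, count), then one uint8 per label
    with gzip.open(path, "rb") as f:
        data = f.read()
    return np.frombuffer(data, dtype=np.uint8, offset=8)

# paths below are placeholders for wherever the downloaded files live
# images = load_idx_images("MNIST/train-images-idx3-ubyte.gz")
# labels = load_idx_labels("MNIST/train-labels-idx1-ubyte.gz")
# print(images.shape, labels.shape)  # (60000, 28, 28) (60000,)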
2. Code Structure
data: the dataset downloaded online by torchvision (this is what the program uses)
MNIST: the manually downloaded dataset (not used by the program)
save_model: holds the .pt file with the saved model parameters
dataset.py: data-processing script
cnn.py: network model script
train.py: training and testing script
eval.py: evaluation script
3. Code
dataset.py
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms as tsf

batch_size = 64

# convert images to tensors and normalize with MNIST's mean (0.1307) and std (0.3081)
transform = tsf.Compose([tsf.ToTensor(), tsf.Normalize([0.1307], [0.3081])])

# torchvision downloads the dataset into ./data on first run
train_set = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="data", train=False, download=True, transform=transform)

def get_data_loader():
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
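As a quick sanity check of dataset.py (a sketch, run from the project directory): each batch from the loader is a tensor of 64 normalized 1x28x28 images plus 64 integer labels.
from dataset import get_data_loader

train_loader, test_loader = get_data_loader()
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 1, 28, 28])
print(labels.shape)   # torch.Size([64])
print(len(train_loader.dataset), len(test_loader.dataset))  # 60000 10000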
cnn.py
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 1x28x28 input -> conv1 (5x5) -> 10x24x24
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=(5, 5))
        # after 2x2 max pooling: 10x12x12 -> conv2 (3x3) -> 20x10x10
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
        # flattened feature size: 20 * 10 * 10 = 2000
        self.fc1 = nn.Linear(in_features=20 * 10 * 10, out_features=500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        input_size = x.size(0)  # batch size
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = self.conv2(x)
        x = F.relu(x)
        x = x.view(input_size, -1)  # flatten to (batch, 2000)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)  # log-probabilities over the 10 classes
        return output
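The in_features=20*10*10 of fc1 comes from tracing the shapes: a 28x28 input becomes 24x24 after the 5x5 conv1, 12x12 after 2x2 max pooling, and 10x10 with 20 channels after the 3x3 conv2, i.e. 2000 values per image once flattened. A tiny sketch to confirm the output shape:
import torch
from cnn import CNN

model = CNN()
x = torch.randn(1, 1, 28, 28)  # one fake single-channel 28x28 "digit"
out = model(x)
print(out.shape)  # torch.Size([1, 10]), one log-probability per class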
train.py
import torch
import cnn
import torch.nn.functional as F
from dataset import get_data_loader
import torch.optim as optim

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epoch = 5
    model = cnn.CNN().to(device)
    optimizer = optim.Adam(model.parameters())
    train_loader, test_loader = get_data_loader()

    def train(epoch_i):
        model.train()
        for batch_i, (digit, label) in enumerate(train_loader):
            digit, label = digit.to(device), label.to(device)
            optimizer.zero_grad()
            output = model(digit)
            # the model already returns log-probabilities, so use the negative
            # log-likelihood loss rather than cross_entropy (which would apply
            # log_softmax a second time)
            loss = F.nll_loss(output, label)
            loss.backward()
            optimizer.step()
            if batch_i % 100 == 0:
                print("train epoch_i: {} batch_i: {} loss: {: .8f}".format(epoch_i, batch_i, loss.item()))

    def test(epoch_i):
        model.eval()
        acc = 0.
        loss = 0.
        with torch.no_grad():
            for digit, label in test_loader:
                digit, label = digit.to(device), label.to(device)
                output = model(digit)
                loss += F.nll_loss(output, label).item()
                predict = output.max(dim=1, keepdim=True)[1]  # index of the largest log-probability
                acc += predict.eq(label.view_as(predict)).sum().item()
        accuracy = acc / len(test_loader.dataset) * 100
        test_loss = loss / len(test_loader.dataset)
        print("test epoch_i: {} loss: {: .8f} accuracy: {: .4f}%".format(epoch_i, test_loss, accuracy))

    for epoch_i in range(1, epoch + 1):
        train(epoch_i)
        test(epoch_i)
    torch.save(model, "save_model/model.pt")  # save the full model from the last epoch
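train.py saves the entire model object. A common alternative (shown here only as a sketch, with a hypothetical file name; this repo keeps torch.save(model, ...)) is to save just the state_dict, which is lighter and does not depend on pickling the CNN class:
# save only the learned parameters instead of the whole pickled model object
torch.save(model.state_dict(), "save_model/model_state.pt")  # hypothetical file name
# to load later: rebuild the architecture first, then restore the weights
model = cnn.CNN().to(device)
model.load_state_dict(torch.load("save_model/model_state.pt", map_location=device))
model.eval()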
eval.py
import torch
from dataset import get_data_loader

if __name__ == "__main__":
    _, eval_loader = get_data_loader()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # load the full model saved by train.py; map_location lets a GPU-trained
    # model be evaluated on a CPU-only machine
    model = torch.load("save_model/model.pt", map_location=device)
    model.eval()
    acc = 0.
    with torch.no_grad():
        for digit, label in eval_loader:
            digit, label = digit.to(device), label.to(device)
            output = model(digit)
            predict = output.max(dim=1, keepdim=True)[1]
            acc += predict.eq(label.view_as(predict)).sum().item()
    accuracy = acc / len(eval_loader.dataset) * 100
    print("eval accuracy: {: .4f}%".format(accuracy))
4. Run Commands and Test Accuracy
Training: python train.py
Evaluation: python eval.py
Test accuracy during training:
test epoch_i: 1 loss: 0.00083338 accuracy: 98.2700%
test epoch_i: 2 loss: 0.00053913 accuracy: 98.8400%
test epoch_i: 3 loss: 0.00059983 accuracy: 98.8800%
test epoch_i: 4 loss: 0.00063796 accuracy: 98.7200%
test epoch_i: 5 loss: 0.00056076 accuracy: 98.8800%
Evaluation accuracy:
eval accuracy: 98.8800%
Because MNIST does not have a separate validation set, the test set is used as the validation set here. The saved model is the one from the last epoch, so the eval accuracy matches the test accuracy of the last epoch.
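If you wanted to keep the best checkpoint rather than simply the last one, a small sketch of the idea (assuming test() is changed to return the accuracy it computes, and using a hypothetical file name):
best_acc = 0.0
for epoch_i in range(1, epoch + 1):
    train(epoch_i)
    acc = test(epoch_i)  # assumes test() returns the accuracy it prints
    if acc > best_acc:
        best_acc = acc
        torch.save(model, "save_model/best_model.pt")  # hypothetical file name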
This program is only meant to walk through the basic workflow of deep learning with PyTorch; the choice of training data and the exact accuracy are not the point.