代码 原理部分,移步之前的文章 人工智能导论 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 import torchfrom torch.utils.data import DataLoaderfrom torchvision import transformsfrom torchvision.datasets import MNISTimport matplotlib.pyplot as pltclass Net (torch.nn.Module): def __init__ (self ): super ().__init__() self.fc1 = torch.nn.Linear(28 *28 , 64 ) self.fc2 = torch.nn.Linear(64 , 64 ) self.fc3 = torch.nn.Linear(64 , 32 ) self.fc4 = torch.nn.Linear(32 , 32 ) self.fc5 = torch.nn.Linear(32 , 10 ) def forward (self, x ): x = torch.nn.functional.relu(self.fc1(x)) x = torch.nn.functional.relu(self.fc2(x)) x = torch.nn.functional.relu(self.fc3(x)) x = torch.nn.functional.relu(self.fc4(x)) x = torch.nn.functional.log_softmax(self.fc5(x), dim=1 ) return x def get_data_loader (is_train ): to_tensor = transforms.Compose([transforms.ToTensor()]) data_set = MNIST("" , is_train, transform=to_tensor, download=True ) return DataLoader(data_set, batch_size=15 , shuffle=True ) def evaluate (test_data, net ): n_correct = 0 n_total = 0 with torch.no_grad(): for (x, y) in test_data: outputs = net.forward(x.view(-1 , 28 *28 )) for i, output in enumerate (outputs): if torch.argmax(output) == y[i]: n_correct += 1 n_total += 1 return n_correct / n_total def main (): train_data = get_data_loader(is_train=True ) test_data = get_data_loader(is_train=False ) net = Net() print ("initial accuracy:" , evaluate(test_data, net)) optimizer = torch.optim.Adam(net.parameters(), lr=0.001 ) for epoch in range (2 ): for (x, y) in train_data: net.zero_grad() output = net.forward(x.view(-1 , 28 *28 )) loss = torch.nn.functional.nll_loss(output, y) loss.backward() optimizer.step() print ("epoch" , epoch, "accuracy:" , evaluate(test_data, net)) for (n, (x, _)) in enumerate (test_data): if n > 3 : break predict = torch.argmax(net.forward(x[0 ].view(-1 , 28 *28 ))) plt.figure(n) plt.imshow(x[0 ].view(28 , 28 )) plt.title("prediction: " + str (int (predict))) plt.show() if __name__ == "__main__" : main()
代码解读 以下 AIGC 。
nn.Linear torch.nn.Linear
是 PyTorch 中用于创建全连接层(也称为线性层或仿射层)的类。这个层的主要功能是将输入数据与一个权重矩阵相乘,并加上一个偏置向量。它可以用来构建神经网络的全连接层。
torch.nn.Linear
返回的是一个线性层对象,它本质上是一个可调用的对象 (即可以像函数一样调用它)。将一个输入张量传递给这个线性层时,它会对输入进行线性变换,并返回一个新的张量。
示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 import torchimport torch.nn as nnlinear = nn.Linear(28 *28 , 64 ) x = torch.randn(10 , 28 *28 ) output = linear(x) print (output.shape)
log_softmax log_softmax
对输出归一化 。它是 softmax
函数的对数版本,通常用于多分类任务的神经网络输出层。
softmax
将一个实数向量转换为概率分布,输出的每个值在 0 到 1 之间,并且和为 1 。公式为:
其中 $x_i$ 是输入向量的第 $i$ 个元素,$\sum_{j} e^{x_j}$ 是所有输入值的指数和。
log_softmax
是 softmax
函数的对数形式,公式为:
先 softmax
再取对数,容易出现数值不稳定的问题,尤其是当数值非常小时,可能导致数值下溢。而 log_softmax
在内部将这两个操作结合起来计算,避免了这种数值不稳定。
在多分类任务中,交叉熵损失函数常常与 log_softmax
一起使用。具体来说,PyTorch 中的 torch.nn.functional.nll_loss
(负对数似然损失)要求输入的是对数概率,而不是直接的概率。log_softmax
的输出正好符合 nll_loss
的输入要求。
nll_loss nll_loss
用于计算模型输出的对数概率与目标标签之间的差异。它的核心思想是通过最小化负对数似然来使模型的预测与真实标签更接近。
在代码中,输出经过 log_softmax
,给出每个类别的对数概率,nll_loss
则计算这些对数概率与真实标签的匹配程度。通过最小化这个损失,模型能够逐渐学会正确分类。
假设模型的输出是一个概率分布,nll_loss
对应的公式为:
其中:
$ N $ 是样本的数量,
$ p(y_i) $ 是模型对样本 $ i $ 的正确类别 $ y_i $ 的预测概率(经过 log_softmax
后已经是对数概率),
$ L $ 是最终的平均损失值。
get_data_loader get_data_loader
函数用于加载 MNIST 数据集,并返回一个 PyTorch 的 DataLoader
对象。DataLoader
是 PyTorch 中的一个重要组件,用于批量化处理数据集,以便在训练或测试时高效地加载和使用数据。
1 to_tensor = transforms.Compose([transforms.ToTensor()])
transforms
是 torchvision
提供的工具,用于对图像进行预处理。transforms.Compose
是一个将多个变换组合在一起的函数,这里只用了一个变换 transforms.ToTensor()
。
transforms.ToTensor()
将 PIL 图像或 NumPy 数组转换为 PyTorch 的张量(tensor),并将像素值缩放到 [0, 1]
之间。MNIST 数据集中原始像素值是 0
到 255
,而 ToTensor()
会自动将其归一化。
1 return DataLoader(data_set, batch_size=15 , shuffle=True )
DataLoader
是 PyTorch 中用于处理和批量化数据集的工具。它会将 data_set
(MNIST 数据集)分批次加载,每次返回指定数量的数据。具体参数解释如下:
batch_size=15
:指定每个批次包含 15 张图片。在训练神经网络时,通常不使用整个数据集,而是将数据集分成多个批次(batch),在每个批次上执行前向传播和反向传播。
shuffle=True
:表示每个 epoch 开始时,打乱数据集。打乱数据可以提高模型训练的随机性,防止模型过拟合于数据的特定顺序。
x.view(-1, 28*28) 1 outputs = net.forward(x.view(-1 , 28 *28 ))
x.view(-1, 28*28)
这部分将输入数据 x
进行重塑(reshape)。MNIST 数据集中的每张图片原本是 28x28 的二维图像张量,但全连接层要求输入的一维张量。因此需要将图片从 28x28 展平为一个一维的 784 维向量(28 * 28 = 784
)。
x
是一个四维张量,形状为 (batch_size, 1, 28, 28)
,其中 batch_size
是当前批次的大小。
view()
是 PyTorch 中的一个张量重塑函数。x.view(-1, 28*28)
的作用是将 x
重塑为形状 (batch_size, 28*28)
的二维张量。
-1
表示自动推断维度,PyTorch 会根据其他维度的大小来推断 batch_size
,即这个维度的大小保持不变。
例如,如果 x
的形状为 (15, 1, 28, 28)
,表示批次大小为 15,每张图片大小为 28x28,则 x.view(-1, 28*28)
会将 x
转换为 (15, 784)
的二维张量。
enumerate() enumerate()
是 Python 内置函数,它允许在循环中同时获得索引 和元素 。对于 outputs
来说,enumerate(outputs)
会返回每个样本的索引和对应的输出值。
i
是当前迭代的索引,表示第 i
个样本。
output
是 outputs
中第 i
个样本的输出,即一个 10 维的张量,包含了该样本对 10 个类别的预测概率。
optimizer 1 optimizer = torch.optim.Adam(net.parameters(), lr=0.001 )
torch.optim.Adam
是 PyTorch 中实现的 Adam 优化器。Adam(Adaptive Moment Estimation)是一种常用的优化算法,结合了动量法 和RMSProp 优化器的优点,它通过自适应地调整学习率来加快训练速度,且在处理稀疏梯度问题时表现很好。
net.parameters()
是一个函数,返回神经网络模型 Net
的所有可训练参数。每个神经网络层(如线性层 fc1
, fc2
等)都会包含其自己的参数(权重和偏置),这些参数会随着训练过程逐渐优化。
通过 net.parameters()
,优化器能够访问和更新这些参数。
lr
是学习率(learning rate)的缩写,表示每次参数更新的步长。在梯度下降过程中,学习率决定了模型的权重如何调整:
较小的学习率 (如 0.0001
)会导致训练速度变慢,但更精确。
较大的学习率 (如 0.1
)会导致训练速度加快,但可能不稳定,甚至无法收敛。
在这个例子中,学习率设为 0.001
,是一个常见的选择。Adam 优化器对学习率相对不那么敏感,它能根据数据的特性自适应地调整每个参数的学习率,所以通常这个值不需要调得太精细。
实用拓展 现在我们想要实现这样的功能: 由用户绘制一个数字,利用刚刚训练出的模型,识别这是什么数字。
这个功能在逻辑上并不困难,但是一些代码细节会稍显繁琐。
比如,我们采样鼠标左键,但由于不可避免的时间间隔,会画出一些离散的点。但模型是采用连续笔画的图像训练的,这些离散的点无法得到正确的识别结果。因此,需要对这些点插值。但这带来了进一步的问题,有些数字,例如 4,有两个笔画,我们不能在第一笔末尾的点、第二笔开始的点之间插值。简单的解决方案是,设置一个时间阈值。
再比如,笔画的粗细可能会对识别结果有影响。当然,这可以说是模型本身的问题,但如果你不想再死磕这个模型,可以简单地调整绘画窗口中笔画的粗细。
下面放出代码。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 import torchfrom torch.utils.data import DataLoaderfrom torchvision import transformsfrom torchvision.datasets import MNISTimport matplotlib.pyplot as pltclass Net (torch.nn.Module): def __init__ (self ): super ().__init__() self.fc1 = torch.nn.Linear(28 *28 , 64 ) self.fc2 = torch.nn.Linear(64 , 64 ) self.fc3 = torch.nn.Linear(64 , 64 ) self.fc4 = torch.nn.Linear(64 , 10 ) def forward (self, x ): x = torch.nn.functional.relu(self.fc1(x)) x = torch.nn.functional.relu(self.fc2(x)) x = torch.nn.functional.relu(self.fc3(x)) x = torch.nn.functional.log_softmax(self.fc4(x), dim=1 ) return x def get_data_loader (is_train ): to_tensor = transforms.Compose([transforms.ToTensor()]) data_set = MNIST("" , is_train, transform=to_tensor, download=True ) return DataLoader(data_set, batch_size=15 , shuffle=True ) def evaluate (test_data, net ): n_correct = 0 n_total = 0 with torch.no_grad(): for (x, y) in test_data: outputs = net.forward(x.view(-1 , 28 *28 )) for i, output in enumerate (outputs): if torch.argmax(output) == y[i]: n_correct += 1 n_total += 1 return n_correct / n_total def main (): train_data = get_data_loader(is_train=True ) test_data = get_data_loader(is_train=False ) net = Net() print ("initial accuracy:" , evaluate(test_data, net)) optimizer = torch.optim.Adam(net.parameters(), lr=0.001 ) for epoch in range (4 ): for (x, y) in train_data: net.zero_grad() output = net.forward(x.view(-1 , 28 *28 )) loss = torch.nn.functional.nll_loss(output, y) loss.backward() optimizer.step() print ("epoch" , epoch, "accuracy:" , evaluate(test_data, net)) for (n, (x, _)) in enumerate (test_data): if n > 3 : break predict = torch.argmax(net.forward(x[0 ].view(-1 , 28 *28 ))) plt.figure(n) plt.imshow(x[0 ].view(28 , 28 )) plt.title("prediction: " + str (int (predict))) plt.show() torch.save(net.state_dict(), 'handwrite_model.pth' ) print ("模型已保存为 handwrite_model.pth" ) if __name__ == "__main__" : main()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 import torchimport timeimport numpy as npfrom tkinter import *from PIL import Image, ImageOpsclass Net (torch.nn.Module): def __init__ (self ): super ().__init__() self.fc1 = torch.nn.Linear(28 *28 , 64 ) self.fc2 = torch.nn.Linear(64 , 64 ) self.fc3 = torch.nn.Linear(64 , 64 ) self.fc4 = torch.nn.Linear(64 , 10 ) def forward (self, x ): x = torch.nn.functional.relu(self.fc1(x)) x = torch.nn.functional.relu(self.fc2(x)) x = torch.nn.functional.relu(self.fc3(x)) x = torch.nn.functional.log_softmax(self.fc4(x), dim=1 ) return x class HandwrittenDigitApp : def __init__ (self, model ): self.model = model self.root = Tk() self.canvas = Canvas(self.root, width=280 , height=280 , bg="white" ) self.canvas.pack() self.canvas.bind("<B1-Motion>" , self.paint) Button(self.root, text="Predict" , command=self.predict).pack() Button(self.root, text="Clear" , command=self.clear_canvas).pack() self.last_x, self.last_y = None , None self.last_time = None def paint (self, event ): x, y = event.x, event.y current_time = time.time() if self.last_x is not None and self.last_y is not None : time_interval = current_time - self.last_time if time_interval < 0.3 : self.interpolate_line(self.last_x, self.last_y, x, y) self.canvas.create_oval(x, y, x+6 , y+6 , fill='black' ) self.last_x, self.last_y = x, y self.last_time = current_time def interpolate_line (self, x1, y1, x2, y2 ): distance = max (abs (x2 - x1), abs (y2 - y1)) for i in range (1 , distance): xi = x1 + (x2 - x1) * i / distance yi = y1 + (y2 - y1) * i / distance self.canvas.create_oval(xi, yi, xi+6 , yi+6 , fill='black' ) def clear_canvas (self ): self.canvas.delete("all" ) self.last_x, self.last_y = None , None self.last_time = None def predict (self ): self.canvas.postscript(file="digit.ps" ) img = Image.open ("digit.ps" ).convert("L" ) img.save("digit.png" ) img = img.resize((28 , 28 )) img = ImageOps.invert(img) img_tensor = torch.tensor(np.array(img)).float ().view(-1 , 28 *28 ) output = self.model(img_tensor) prediction = torch.argmax(output) print ("Predicted digit:" , prediction.item()) self.show_prediction(prediction.item()) def show_prediction (self, prediction ): result_window = Toplevel(self.root) result_window.title("Prediction" ) Label(result_window, text=f"Predicted digit: {prediction} " , font=("Helvetica" , 24 )).pack() def run (self ): self.root.mainloop() def main (): net = Net() net.load_state_dict(torch.load('handwrite_model.pth' , weights_only=True )) app = HandwrittenDigitApp(net) app.run() if __name__ == "__main__" : main()