适合刚入门 PyTorch、想快速跑通完整机器学习 pipeline的同学全文结构数据准备 → 构建模型 → 训练优化 → 模型评估 → 保存加载 → 单张预测一、前言这篇笔记把 PyTorch 官方最快入门案例完整拆解一行代码一个知识点帮你快速掌握数据集怎么加载模型怎么定义训练循环怎么写模型怎么保存与推理任务FashionMNIST 服饰图片分类10 分类二、环境与依赖importtorchfromtorchimportnnfromtorch.utils.dataimportDataLoaderfromtorchvisionimportdatasetsfromtorchvision.transformsimportToTensor三、数据准备Dataset DataLoader1. 下载数据集# 训练集training_datadatasets.FashionMNIST(rootdata,trainTrue,downloadTrue,transformToTensor(),)# 测试集test_datadatasets.FashionMNIST(rootdata,trainFalse,downloadTrue,transformToTensor(),)2. 构建 DataLoaderbatch_size64train_dataloaderDataLoader(training_data,batch_sizebatch_size)test_dataloaderDataLoader(test_data,batch_sizebatch_size)# 查看数据形状forX,yintest_dataloader:print(fShape of X [N, C, H, W]:{X.shape})print(fShape of y:{y.shape}{y.dtype})break输出Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28]) Shape of y: torch.Size([64]) torch.int64四、构建模型继承 nn.Module1. 选择设备devicetorch.accelerator.current_accelerator().typeiftorch.accelerator.is_available()elsecpuprint(fUsing{device}device)2. 定义网络结构classNeuralNetwork(nn.Module):def__init__(self):super().__init__()self.flattennn.Flatten()self.linear_relu_stacknn.Sequential(nn.Linear(28*28,512),nn.ReLU(),nn.Linear(512,512),nn.ReLU(),nn.Linear(512,10))defforward(self,x):xself.flatten(x)logitsself.linear_relu_stack(x)returnlogits modelNeuralNetwork().to(device)print(model)五、训练相关配置损失函数 优化器loss_fnnn.CrossEntropyLoss()optimizertorch.optim.SGD(model.parameters(),lr1e-3)六、训练函数 测试函数1. 训练函数deftrain(dataloader,model,loss_fn,optimizer):sizelen(dataloader.dataset)model.train()forbatch,(X,y)inenumerate(dataloader):X,yX.to(device),y.to(device)# 前向传播predmodel(X)lossloss_fn(pred,y)# 反向传播 更新loss.backward()optimizer.step()optimizer.zero_grad()ifbatch%1000:loss,currentloss.item(),(batch1)*len(X)print(floss:{loss:7f}[{current:5d}/{size:5d}])2. 测试函数deftest(dataloader,model,loss_fn):sizelen(dataloader.dataset)num_batcheslen(dataloader)model.eval()test_loss,correct0,0withtorch.no_grad():forX,yindataloader:X,yX.to(device),y.to(device)predmodel(X)test_lossloss_fn(pred,y).item()correct(pred.argmax(1)y).type(torch.float).sum().item()test_loss/num_batches correct/sizeprint(fTest Error:)print(f Accuracy:{(100*correct):0.1f}%, Avg loss:{test_loss:8f}\n)七、开始训练epochs5fortinrange(epochs):print(fEpoch{t1}\n-------------------------------)train(train_dataloader,model,loss_fn,optimizer)test(test_dataloader,model,loss_fn)print(Done!)八、模型保存与加载1. 保存模型torch.save(model.state_dict(),model.pth)print(Saved PyTorch Model State to model.pth)2. 加载模型modelNeuralNetwork().to(device)model.load_state_dict(torch.load(model.pth,weights_onlyTrue))九、单张图片推理classes[T-shirt/top,Trouser,Pullover,Dress,Coat,Sandal,Shirt,Sneaker,Bag,Ankle boot,]model.eval()x,ytest_data[0][0],test_data[0][1]withtorch.no_grad():xx.to(device)predmodel(x)predicted,actualclasses[pred[0].argmax(0)],classes[y]print(fPredicted: {predicted}, Actual: {actual})十、整体流程总结必背Dataset DataLoader搞定数据class 继承 nn.Module定义模型loss_fn optimizer配置训练train 函数前向 → loss → 反向 → 更新 → 清零test 函数eval() no_grad()save/load完成模型持久化eval() no_grad()做推理