基于PyTorch图像分类的项目复现总结
基于PyTorch图像分类的项目复现总结1. 项目背景1.1 项目简介本项目是一个基于PyTorch的深度学习图像分类训练框架旨在为研究人员和开发者提供一个完整、灵活且高效的图像分类模型训练解决方案。框架集成了多种经典的卷积神经网络架构支持多个主流图像数据集并实现了混合精度训练、异步数据加载、分布式训练等核心技术优化。1.2 项目目标本项目一提升个人能力为目标便于以后自己在写项目的时候可以借鉴该项目中使用的一些方法1.3 技术栈深度学习框架: PyTorch数据处理: torchvision、NumPy可视化: TensorBoard、Matplotlib、Seaborn评估工具: scikit-learn配置文件: YAML2. 数据集说明2.1 数据集来源本框架支持以下主流图像分类数据集均通过PyTorch官方 torchvision 库自动下载数据集来源用途MNISTLeCun等人创建的手写数字数据库入门级图像分类FashionMNISTZalando服装图片数据集复杂度的入门级任务CIFAR-10加拿大高级研究院收集的32×32彩色图像中级图像分类任务CIFAR-100CIFAR-10的扩展100个类别细粒度图像分类ImageNet大规模视觉识别挑战赛(ILSVRC)数据集大规模图像分类与预训练2.2 数据规模数据集训练集测试集类别数图像尺寸通道数MNIST60,00010,0001028×28灰度(1)FashionMNIST60,00010,0001028×28灰度(1)CIFAR-1050,00010,0001032×32RGB(3)CIFAR-10050,00010,00010032×32RGB(3)ImageNet~1,280,00050,0001000可变RGB(3)2.3 数据集结构项目采用以下目录结构管理数据data/ ├── mnist/ │ └── MNIST/ │ ├── raw/ │ │ ├── train-images-idx3-ubyte.gz │ │ ├── train-labels-idx1-ubyte.gz │ │ ├── t10k-images-idx3-ubyte.gz │ │ └── t10k-labels-idx1-ubyte.gz │ └── processed/ │ ├── training.pt │ └── test.pt ├── cifar10/ │ └── cifar-10-batches-py/ │ ├── data_batch_1~5 │ ├── test_batch │ └── batches.meta ├── cifar100/ │ └── cifar-100-python/ │ ├── train │ ├── test │ └── meta └── imagenet/ ├── train/ │ └── [class_folder]/ │ └── [image].JPEG └── val/ └── [class_folder]/ └── [image].JPEG2.4 数据预处理框架为每个数据集定义了标准化的预处理流程MNIST / FashionMNIST:transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) # MNIST均值和标准差 ])CIFAR-10:transforms.Compose([ transforms.RandomCrop(32, padding4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ])CIFAR-100:(增强版)transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding4), transforms.ColorJitter(brightness0.3, contrast0.2, saturation0.4, hue0.4), transforms.RandomAffine(degrees0.3, translate(0.1, 0.1), scale(0.5, 1.25), shear5), transforms.RandomGrayscale(), transforms.ToTensor(), transforms.Normalize(mean, std) ])3. 核心技术概述3.1 三大核心技术技术说明作用混合精度训练 (AMP)使用torch.cuda.amp.GradScaler和autocast加速训练、减少显存占用数据预加载 (DataPrefetcher)使用 CUDA Stream 异步预加载数据减少数据等待时间分布式训练支持多GPU并行 (NCCL后端)加速大规模训练3.2 学习率策略对比策略名称公式/特点适用场景warmup线性增长lr base_lr * (epoch1) / warmup_length防止初期训练不稳定cosine_lrlr 0.5 * (1 cos(π * e / es)) * base_lr平滑衰减后期收敛平稳normal_lr高斯衰减lr base_lr * exp(-(e-init)² / scale)中期快速下降multistep_lr阶梯式衰减在指定epoch处乘以gamma简单有效广泛使用constant_lr保持不变调试/基线实验4. 各模块文件详解4.1datasets.py —— 数据加载与预处理作用加载多种图像数据集实现数据增强和预处理核心函数def get_data(args) ​ 对数据集的处理方法 1.对每一种数据集设置专门的transform 用用于数据变换train和test可能会设置不同的情况 ​ 2.下载数据集 train_dataset datasets.MNIST() train_dataset datasets.CIFAR10() train_dataset datasets.CIFAR100() train_dataset datasets.FashionMNIST() ImageNet 要自己下载 datasets.ImageFolder() 下载好之后要用这个方法去加载 ​ ​ ​ def loaders(train_dataset, validation_dataset, args): 1.确定内存格式 # channels_last(更高效)或contigous,因为channels_last更适合卷积运算训练速度提升显著 if args.channels_last: memory_format torch.channels_last else: memory_format torch.contiguous_format 2.创建数据加载器 train_loader torch.utils.data.DataLoader( train_dataset, batch_size shuffle num_workers 数据加载线程数 pin_memory 锁页内存加速数据传输到GPU )数据增强方法方法说明RandomCrop随机裁剪RandomHorizontalFlip随机水平翻转ColorJitter随机调整颜色/亮度/对比度RandomAffine随机仿射变换RandomGrayscale随机转灰度图Normalize标准化ToTensor转为Tensor内存格式选择if args.channels_last: memory_format torch.channels_last # 推荐卷积运算更高效 else: memory_format torch.contiguous_format4.2 模型文件 (models/bare)张量展平方式对比# 假设 x.shape (batch, channels, height, width) (4, 64, 7, 7) # 结果相同: (4, 3136) ​ x torch.flatten(x, 1) # 从第1维开始展平 x x.view(x.size(0), -1) # 保留batch维展平其余 x x.reshape(x.shape[0], -1) # 自动计算维度卷积层参数参数说明in_planes输入通道数out_planes输出通道数kernel_size卷积核大小 (默认3)stride步长 (默认1)groups分组卷积 (默认1)dilation膨胀率 (默认1)bias偏置 (默认True)分类器构建示例self.classifier nn.Sequential( nn.Dropout(0.2), # 随机丢弃20%神经元防止过拟合 nn.Linear(last_channel, num_classes) )4.3 arguments.py —— 参数管理功能解析命令行参数读取YAML配置文件命令行参数覆盖配置文件#创建命令行参数解析器 ​ def parse(): #创建参数解析器 parser argparse.ArgumentParser( descriptionComprehensive image classification ) ​ #创建参数配置 #参1命令行参数名 参2短参数名 参3帮助信息中显示的参数名用于多值参数时 #参4参数类型 参5参数描述 parser.add_argument(--config,-c,metavarC,defaultminist_fc_train,typestr,helpyaml配置文件名称) ​ ... ... ... ​ ​ ​ # 解析参数 - 解析命令行传入的参数如 --arch mobilenet_v2 - 生成 args 对象 args parser.parse_args() ​ # 读取YAML配置文件 # 路径格式configs / 数据集名 / 模型名 / 配置名.yaml with open(f./configs/{args.dataname}/{args.arch}/{args.config}.yaml, r) as f: yaml_dict yaml.load(f, Loaderyaml.FullLoader) ​ # 命令行参数覆盖配置文件 for arg in vars(args): if -- arg.replace(_, -) in sys.argv: yaml_dict[arg] getattr(args, arg) vars(args) | 返回 args 的所有属性字典 | | arg.replace(_, -) | 将下划线转为横线如 lr_steps → lr-steps | | -- arg | 构造完整参数名如 --lr-steps | | sys.argv | 命令行参数列表 | | getattr(args, arg) | 获取 args 对象中该属性的值 | 作用如果命令行指定了某参数就用命令行值覆盖 YAML 配置 # 更新args对象 update() 合并字典 args.__dict__.update(yaml_dict) ​ return args4.4 train_validation.py —— 训练与验证训练流程 (train方法)#模型训练步骤 def train(self,epoch): #epoch当前epoch编号 #1.初始化需要用到的参数 self.epoch epoch self.epoch_length len(self.train_loader) #2.创建指标计算器 self.batch_time AverageMeter() #批次时间 self.losses_tr AverageMeter() #训练损失 self.top1_tr AverageMeter() #Top-1准确率 self.top5_tr AverageMeter() #Top-5准确率 #3.切换模型为训练模式 self.model.train() #4.计算训练开始时间 self.end_tr time.time() #5.使用数据预处理器加速数据加载 prefetcher DataPrefetcher(self.train_loader,self.args.dataname) #6.获取第一次训练的数据 input,target prefetcher.next() i0 #7.训练开始 while input is not None: i 1 #8.计算应用学习率并应用到优化器更新 self.lr self.lr_policy.apply_lr(self.epoch) self.assign_learning_rate(self.lr) #9.前向传播混合精度 with autocast(): #自动选择float16 / bfloat16 output self.model(input) loss self.criterion(output,target) #10反向传播 self.optimizer.zero_grad() #11混合精度:缩放损失防止梯度下溢 self.scaler.scale(loss).backward() #11计算梯度范数用于监控梯度并保存梯度 total_norm 0 grad_k list() for p in self.model.parameters(): param_norm p.grad.data.norm(2) #L2范数 grad_k.append(p.grad.data.clone()) #保存梯度 total_norm param_norm.item() ** 2 #累加平方 total_norm total_norm ** (1./2) #开根号 self. grad_norm total_norm #记录梯度范数 #12.更新参数混合精度 self.scaler.step(self.optimizer) #根据scale后梯度更新参数 self.scaler.update() #调整scale因子 self.step 1 #计算step数 #13.计算准确率(Top-1 和 Top-5) self.prec1_tr,self.prec5_tr self.accuracy(output.data,target,topk(1,5)) #14.记录loss self.reduced_loss_tr loss.data #15.更新指标累计平均值 self.losses_tr.update(self.reduced_loss_tr.data.item(), input.size(0)) self.top1_tr.update(self.prec1_tr.data.item(), input.size(0)) self.top5_tr.update(self.prec5_tr.data.item(), input.size(0)) #16.写入TensorBoard self.write_net_values(trainTrue) #17.计算时间 self.batch_time.update(time.time() - self.end_tr()) #18.重置计时起点 self.end_tr time.time() # 19.定期打印训练信息 if i % self.args.print_freq_tr 0: # 当i被self.print_freq_tr整除时 # 计算吞吐量每秒处理图片数 self.curr_throughput_tr self.args.batch_size / self.batch_time.val self.avg_throughput_tr self.args.batch_size / self.batch_time.avg # 打印格式 print( fTrain: GPU数:{self.args.world_size:2} | \ fEpoch: [{self.epoch 1:2}/{self.args.epochs:2}] | \ fBatch: [{i:4}/{len(self.train_loader):4}] | \ f时间: {self.args.print_freq_tr * self.batch_time.avg:4.2f}s | \ f速度: {self.avg_throughput_tr:5.0f} pics/s\n \ f学习率: {self.lr:5.4f} | \ f当前loss: {self.losses_tr.val:3.2f} | \ f平均loss: {self.losses_tr.avg:3.2f} | \ fTop1: {self.top1_tr.avg:4.2f}% | \ fTop5: {self.top5_tr.avg:4.2f}%\n) | torch.cuda.synchronize() | 等待所有GPU计算完成 | | time.time() | 获取当前时间戳秒 | | self.end_tr | 上一个时间点用于计算时间差 | | self.batch_time.val | 当前batch耗时 | | self.batch_time.avg | 每个batch累计平均耗时 | 吞吐量相关 | 变量 | 说明 | |------|------| | self.args.world_size | GPU数量 | | self.args.batch_size | 每个GPU的batch大小 | | self.args.print_freq_tr | 打印频率每多少batch打印一次 | | self.curr_throughput_tr | 当前吞吐量图片/秒 | | self.avg_throughput_tr | 平均吞吐量 | #20.加载下一个batch input,target Prefetcher.next()验证流程 (validation方法)def validation(self,epoch,report): epoch当前epoch编号 report是否打印报告 #1.初始化需要使用到的参数 self.epoch epoch #2.创建指标计算器 self.batch_time_ts AverageMeter() #记录时间 self.losses_ts AverageMeter() #记录损失 self.top1_ts AverageMeter() #记录top-1准确率 self.top5_ts AverageMeter() #记录top-5准确率 ​ #3.切换到评估模式 self.model.eval() #4.计算评估开始时间 self.end_ts time.time() ​ #5.使用数据预处理器加速数据加载 prefetcher DataPrefetcher(self.validation,self.args.dataname) #6.获取第一个batch的数据 input,target prefetcher.next() i 0 #7.评估开始 while input is not None: i 1 #8.前向传播 with torch.no_grad(): output self.model(input) loss self.criterion(output,target) #9.计算准确率Top-1 和 Top-5 self.prec1_tr,self.prec5_tr self.accuracy(output,target,topk(1,5)) #10.记录损失 self.reduced_loss_tr loss.data #11.更新指标累计平均值 self.losses_ts.update(self.reduced_loss_ts.data.item(),input.size(0)) self.top1_ts.update(self.prec1_tr.data.item(),input.size(0)) self.top5_ts.update(self.prec5_tr.data.item(),input.size(0)) #12.计算时间 self.batch_time_ts.update(time.time() - self.end_ts) #13.重置计时起点 self.end_ts time.time() #14.获取下一个batch的数据 input,target prefetcher.next() #15.打印验证信息 self.args.print_freq_ts n 代表每n个batch打印一次 if i % self.args.print_freq_ts 0 and report: #计算吞吐量每秒处理图片数 args.batch_sizeGPU每批次处理的样本数 self.curr_thoughput_ts self.args.batch_size / self.batch_time_ts.val self.avg_thtoughput_ts self.args.batch_size / self.batch_time_ts.avg print( f验证: Epoch: [{self.epoch1:2}/{self.args.epochs:2}] | \ f已处理数据: [{i:4}/{len(self.validation_loader):4}] | \ f平均耗时: {self.args.print_freq_ts*self.batch_time_ts.avg:4.2f} | \ f速度: (张/秒): {self.avg_throughput_ts:5.0f}\n\ f当前损失: {self.losses_ts.val:3.2f} | \ f平均损失: {self.losses_ts.avg:3.2f} | \ fTop1准确率: {self.top1_ts.avg:4.2f} % | \ fTop5准确率: {self.top5_ts.avg:4.2f} %\n) #16.记录验证阶段的指标 self.write_net_values(trainFalse) #17返回平均损失准确率 return self.losses_ts.avg,self.top1_ts.avg,self.top5_ts.avg​核心工具类AverageMeter —— 指标追踪class AverageMeter(object): def __init__(self): self.reset() ​ def reset(self): self.val 0 # 当前值 self.avg 0 # 平均值 self.sum 0 # 累计和 self.count 0 # 样本数 ​ ​ # 更新计量器 def update(self, val, n1): val: 新值 n: 样本数量 self.val val self.sum val * n self.count n self.avg self.sum / self.count​DataPrefetcher —— 数据异步加载class DataPrefetcher(object): def __init__(self, loader, dataname): #将Dataloader转为迭代器 self.loader iter(loader) #数据集名称 self.dataname dataname #创建新的CUDA流 self.stream torch.cuda.Stream() ​ ​ #预加载第一个batch self.preload() def preload(self): #从Dataloader获取数据 try: self.next_input, self.next_target next(self.loader) except StopIteration: self.next_input None self.next_target None return ​ # 使用CUDA流异步传输数据 with torch.cuda.stream(self.stream):# 在独立 CUDA 流中异步传输 self.next_input self.next_input.cuda(non_blockingTrue) self.next_target self.next_target.cuda(non_blockingTrue) | self.stream | 独立的 CUDA 流类似后台线程 | | non_blockingTrue | 非阻塞传输不等待数据传输完成就返回 | | .cuda() | 将数据从 CPU 移到 GPU | ​ ​ # 转换为float并标准化 self.next_input self.next_input.float() ​ ​ #返回下一个batch def next(self): # 1. 等待异步传输完成 torch.cuda.current_stream().wait_stream(self.stream) ​ # 2. 获取当前 batch 的数据 input self.next_input target self.next_target ​ # 3. 记录数据使用的 CUDA 流 if input is not None: input.record_stream(torch.cuda.current_stream()) if target is not None: target.record_stream(torch.cuda.current_stream()) ​ # 4. 预取下一个 batch self.preload() ​ # 5. 返回当前 batch return input, target​ ​ ​accuracy —— Top-K准确率计算def accuracy(self,output,target,topk(1,)) topk需要计算的top-k 加逗号是为了让topk成为一个元组可以迭代遍历 #1.获取最大的k值 maxk max(topk) #2.获取batch的大小 batch_size target.size(0) #3.获取概率最高的前k个类别 _,pred output.topk(maxk,1,True,True) #[batch_size,maxk] #4.转置 pred pred.t() #[maxk,batch_sixe] #5.判断预测类别是否与真实类别一致 correct pred.eq(target.view(1,-1).expand_as(pred)) .view(1,-1) 将其变为 [1,batch_size] .expand_as(pred) 将target广播到 [maxk,batch_size] .eq() 逐个元素比较相等为True #6.创建保存概率的列表 res [] #7.计算每个k的准确率 for k in topk: correct_k correct[:k].reshape(-1).float().sum(0,keepdimTrue) res.append(correct_k.mul_(100.0 / batch_size)) .reshape(-1) 展平为1D张量 keepdim 保持维度不退化 return res4.5 main.py —— 主程序入口def main(): #1.解析命令行参数 args parse() #2.启动cudann自动优化 cudnn.benchmark True #3.创建当前场景名称包含时间戳和配置信息 scenario Scenario(args) #4.负责日志创建目录和TensorBoard写入 writer init_writer(args,scenario) #5.数据集加载 train_dataset,validation_dataset get_data(args) ​ #6.创建数据加载器 train_loader,validation_loader,train_sampler loader(train_dataset,validation_dataset,args) ​ ​ ​ #7.创建模型 model Archs(args).model.cuda(cuda_id) #8.模型初始化 network_init(model,args) #9.创建一阶优化器 optimizer SFO(model,args) #10.创建损失函数 if args.label_smoothing 0: criterion nn.CrossEntropyLoss().cuda() else: criterion LabelSmoothing(smoothingargs.label_smoothing) #11.混合精度训练 scaler torch.cuda.amp.GradScaler() #12.训练与验证 traval TraVal(model,train_loader,optimizer,criterion,scaler, args,validation_loader,writer writer if args.local_rank 0 else None, curr_scen_namescenario.curr_scen_name if args.local_rank 0 else None) ​ ​ #13.保存模型和变量 outputwriter Outputwriter(model,traval,writer,args,scenario.curr_scen_name) ​ ​ #14.保存初始模型参数 outputwriter.net_params_init() ​ ​ #15.训练 for epoch in range(args.initial_epoch,args.epochs): #训练一个epoch traval.train(epoch) #验证一个epoch traval.validation(epoch,reportTrue) ​ #16.保存训练过程的统计量 traval.stage_quantities() ​ #17.保存最终模型参数 if args.local_rank 0: outputwriter.net_params() ​ if __name__ __main__: main()​5. 关键概念深入理解5.1 线性层输入维度线性层接受 2D 张量(batch_size, features)features 卷积层提取的特征总数。示例 (Conv2, 输入32x32)卷积池化后: (batch, 64, 16, 16) 展平后: (batch, 64×16×16) (batch, 16384) fc1: (batch, 16384) → (batch, 256)5.2__all__的作用__all__ [MobileNetV2, mobilenet_v2]控制from xxx import *时可导入的内容场景有__all__无__all__from xxx import *只导入指定内容导入所有非_开头5.3 动态模型实例化torch_models.__dict__[args.arch]() # 动态获取并实例化 # 等价于 torch_models.resnet18()5.4 空洞卷积 vs 步长卷积操作方式效果步长卷积(stride2)跳跃采样特征图尺寸减半空洞卷积(dilation1)卷积核插入空洞感受野增大保持尺寸以使用步长卷积应用于下采样为例当残差块的 stride2 时输入输出尺寸不匹配需要下采样模块downsample nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), )作用改变尺寸stride2 将 28x28 → 14x14改变通道1x1 卷积可以改变通道数对齐形状让 F(x) 和 x 可以相加5.5 残差连接 (Residual Block)为什么需要残差连接解决深层网络训练困难对比普通网络残差网络公式y F(x)y F(x) x梯度容易消失恒等路径保持梯度退化深层性能下降至少不差于浅层5.6 1×1 卷积作用改变通道数 (64→128)跨通道信息融合减少计算量 (比3×3少9倍参数)在ResNet Bottleneck中的使用输入: 256通道 ↓ conv1x1: 256→64 (降维) ↓ conv3x3: 64→64 (提取特征) ↓ conv1x1: 64→256 (升维) 输出: 256通道6. 训练技巧6.1 权重初始化对比初始化方法公式适用场景Xavier (fan_in)保持前向传播方差不变Sigmoid, TanhXavier (fan_out)保持反向传播方差不变Sigmoid, TanhKaiming (fan_out)保持反向传播梯度稳定ReLU(推荐)Kaiming (fan_in)保持前向传播方差不变ReLU参数初始化 - mobilenet.py 201行 for m in self.modules(): # Kaiming 初始化适合 ReLU 激活函数 if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, modefan_out) #正太初始化 if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): # BN 权重初始化为 1偏置为 0 nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): # 线性层权重正态分布均值0标准差0.01 nn.init.normal_(m.weight, 0, 0.01) nn.init.zeros_(m.bias)6.2 Label Smoothing 标签平滑将硬标签转为软标签防止过拟合# 原始: [0, 0, 1, 0, 0] → 平滑后: [0.025, 0.025, 0.9, 0.025, 0.025] criterion LabelSmoothing(smoothing0.1)6.3 混合精度训练 (AMP)scaler torch.cuda.amp.GradScaler() ​ # 前向传播 with autocast(): output model(input) loss criterion(output, target) ​ # 反向传播 scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()6.4数据预取优化技术说明效果pin_memoryTrue锁页内存加速CPU→GPU传输num_workers数据加载线程数并行加载DataPrefetcherCUDA流异步预取消除数据等待7. 模型训练7.1 训练环境配置Windows环境特别配置# 必须设置否则会遇到OpenMP冲突错误 set KMP_DUPLICATE_LIB_OKTRUE7.2 训练命令基础训练命令# MNIST 全连接网络 python main.py -a fc --dataname mnist --config mnist_fc_train --epochs 10 --logterminal ​ # CIFAR-10 CNN模型 python main.py -a conv4 --dataname cifar10 --config cifar10_conv4_train --epochs 20 --logterminal7.3 参数说明参数缩写说明示例--arch-a模型架构fc,conv4,resnet18,vgg16--dataname-d数据集名称mnist,cifar10,cifar100--config-c配置文件名mnist_fc_train--epochs-e训练轮数10,50,100--batch-size-b批次大小32,64,128--lr-初始学习率0.001,0.01--gpu-ids-GPU设备ID0,0 1 2 3--logterminal-终端输出日志---resume-从检查点恢复./checkpoints/model.pt7.4 训练输出训练完成后会在以下目录生成输出文件history/ ├── logs/ # 训练日志 │ └── {数据集}/{模型}/ │ └── {时间戳}_{配置信息}.txt │ ├── saved/ # 保存的模型权重 │ └── {数据集}/{模型}/ │ ├── saved_initialization/ # 初始化权重训练前 │ │ └── {时间戳}/ │ │ └── initialization_{配置信息}.pth.tar │ │ │ └── saved_end_of_training/ # 训练结束后的权重 │ └── {时间戳}/ │ └── _data_{配置信息}.pth.tar │ └── variables/ # NumPy格式训练指标 └── {数据集}/{模型}/ └── {时间戳}/ ├── {时间戳}TrainingLoss_{配置信息}.npy ├── {时间戳}TrainingTop1_{配置信息}.npy ├── {时间戳}TrainingTop5_{配置信息}.npy ├── {时间戳}TestLoss_{配置信息}.npy ├── {时间戳}TestTop1_{配置信息}.npy ├── {时间戳}TestTop5_{配置信息}.npy ├── {时间戳}Learning_Rate_{配置信息}.npy └── {时间戳}GradNorm_{配置信息}.npy7.5模型训练结果8. 模型评估8.1 评估模块概述项目提供了独立的模型评估模块evaluate.py支持加载训练好的模型权重在测试集上计算准确率生成混淆矩阵输出详细分类报告8.2 评估命令# 基本评估 python utils/evaluate.py --checkpoint 训练好的模型参数路径 --dataname 参数对应的数据集名称 --arch 模型名称8.3 评估指标指标说明准确率(Accuracy)正确预测样本数 / 总样本数损失(Loss)交叉熵损失Top-1 准确率预测概率最高的类别正确率Top-5 准确率预测概率前5包含正确类别的概率(主要用于ImageNet)8.4 混淆矩阵评估模块会自动生成混淆矩阵可视化图直观展示各类别的分类效果def plot_confusion_matrix(target_list, pred_list, classes, save_path): cm confusion_matrix(target_list, pred_list) plt.figure(figsize(12, 10)) sns.heatmap(cm, annotTrue, fmtd, cmapBlues, xticklabelsclasses, yticklabelsclasses) plt.xlabel(Predicted Label) plt.ylabel(True Label) plt.title(Confusion Matrix) plt.savefig(save_path, dpi150)8.5 分类报告评估模块输出详细的分类报告包括每个类别的指标说明Precision精确率预测为正类的样本中实际为正类的比例Recall召回率实际为正类的样本中被正确预测的比例F1-score精确率和召回率的调和平均Support该类别的样本数量8.6 模型评估结果注本项目中使用了多GPU即分布式训练的方法但是由于本人实际上大概率使用不到该方法所以并没有对该方法进行总结感兴趣的可以查看项目源码项目源码地址

相关新闻

F.动态规划-入门DP-打家劫舍:198. 打家劫舍

F.动态规划-入门DP-打家劫舍:198. 打家劫舍

题目链接:198. 打家劫舍(中等) LCR 089. 打家劫舍(中等) 算法原理: 此题与下题完全相同👇 动态规划算法-简单多状态dp问题:11.按摩师 解法:动态规划 0ms击败100.00% 时间…

2026/7/4 16:43:08 阅读更多 →
“三晋优品 乐购云端”直播活动开始啦!

“三晋优品 乐购云端”直播活动开始啦!

2026/7/3 4:45:33 阅读更多 →
c/c++高频面试:TCP粘包三种解决方案

c/c++高频面试:TCP粘包三种解决方案

1. 消息定长 (Fixed-Length Messages)原理:发送端和接收端约定,每一个消息的长度都是固定的(比如 1024 字节)。实现:如果发送的数据不足 1024 字节,就用空格或 \0 补齐;接收端每次严格读取 1024…

2026/7/3 10:41:26 阅读更多 →

最新新闻

网盘直链下载助手完整指南:一键获取八大网盘真实下载地址的终极解决方案

网盘直链下载助手完整指南:一键获取八大网盘真实下载地址的终极解决方案

网盘直链下载助手完整指南:一键获取八大网盘真实下载地址的终极解决方案 【免费下载链接】Online-disk-direct-link-download-assistant 一个基于 JavaScript 的网盘文件下载地址获取工具。基于【网盘直链下载助手】修改 ,支持 百度网盘 / 阿里云盘 / 中…

2026/7/5 18:33:28 阅读更多 →
如何扩展Runno:添加自定义编程语言运行时的完整指南

如何扩展Runno:添加自定义编程语言运行时的完整指南

如何扩展Runno:添加自定义编程语言运行时的完整指南 【免费下载链接】runno Sandboxed runtime for programming languages and WASI binaries. Works in the browser, on your server, or via MCP. 项目地址: https://gitcode.com/gh_mirrors/ru/runno Runn…

2026/7/5 18:33:28 阅读更多 →
对字符串排序的影响

对字符串排序的影响

字符串的大小比较并不是如C那样按照字符串字符内码大小顺序从头到尾来比较的。由于我是从C/C转过来的,我一直以来都以为.net 下字符串的比较规则和C是一样的,直到有一天我的程序在英文操作系统下出错。 .net 下,字符串的排序受 System.Threa…

2026/7/5 18:29:28 阅读更多 →
Runno高级调试技巧:解决复杂代码执行问题的完整方法

Runno高级调试技巧:解决复杂代码执行问题的完整方法

Runno高级调试技巧:解决复杂代码执行问题的完整方法 【免费下载链接】runno Sandboxed runtime for programming languages and WASI binaries. Works in the browser, on your server, or via MCP. 项目地址: https://gitcode.com/gh_mirrors/ru/runno Runn…

2026/7/5 18:29:28 阅读更多 →
Instatic集群部署:负载均衡与会话共享配置指南

Instatic集群部署:负载均衡与会话共享配置指南

Instatic集群部署:负载均衡与会话共享配置指南 【免费下载链接】Instatic Instatic is a modern self-hosted visual CMS - get it running in 1 minute 项目地址: https://gitcode.com/GitHub_Trending/in/Instatic Instatic作为一款现代自托管视觉CMS&…

2026/7/5 18:25:26 阅读更多 →
CANN/asc-devkit:int8转half数据类型转换API

CANN/asc-devkit:int8转half数据类型转换API

asc_int82half 【免费下载链接】asc-devkit 本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言,原生支持C和C标准规范,主要由类库和语言扩展层构成,提供多层级API,满足多维场景算子开发诉求。 项目地址: https://gitcode.…

2026/7/5 18:25:26 阅读更多 →

日新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

周新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

月新闻