基于 ResNet18 迁移学习的猫狗多分类系统(Oxford-IIIT Pet 数据集)
1.技术栈Python、PyTorch、torchvision、OpenCV、ResNet18、迁移学习、数据增强、模型训练与评估、实验可视化2.核心工作与量化成果基于 PyTorch 完成37 类猫狗分类端到端算法开发采用官方 Oxford-IIIT Pet 数据集7349 张图像通过 OpenCV 与 torchvision 实现数据预处理尺寸归一化 / ImageNet 标准化和数据增强随机裁剪 / 水平翻转有效提升样本多样性缓解过拟合设计迁移学习训练策略加载 ResNet18 预训练权重冻结卷积特征提取层仅训练全连接分类头将输入特征适配 37 类分类任务较从零训练收敛速度提升 80%大幅降低训练成本与硬件资源消耗搭建完整的训练 - 验证 - 测试闭环划分训练 / 验证 / 测试集7:2:1配置 Adam 优化器 交叉熵损失函数设置 5 轮训练迭代实时监控损失与准确率指标最终在测试集实现 84.08% 分类准确率实现模型性能分析与可视化绘制训练 / 验证阶段损失 - 准确率曲线分析模型收敛趋势保存最佳模型权重形成可复现的算法实验流程为后续模型调优超参数优化 / 网络升级提供基础基于 DataLoader 实现高效数据批量加载针对 Windows 环境优化多线程配置保证数据加载与模型训练的兼容性和稳定性。3.代码展示1.导库# PyTorch核心构建张量、模型、优化器、数据加载器 import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, random_split # torchvisionPyTorch官方视觉库包含内置数据集、预处理、经典模型 from torchvision import datasets, transforms, models from torchvision.datasets import OxfordIIITPet # 重点PyTorch内置的猫狗分类数据集 # 基础库仅用于路径处理和可视化非核心可选 import os import matplotlib.pyplot as plt2.全局配置#DEVICE是后续所有张量/模型的运行设备统一配置避免设备不匹配报错 DEVICE torch.device(cuda if torch.cuda.is_available() else cpu) print(f【全局配置】运行设备{DEVICE}) # 数据集根路径PyTorch会自动下载数据集到这个路径 # 建议用绝对路径避免相对路径导致的“找不到文件”问题 DATA_ROOT os.path.join(os.path.expanduser(~), Desktop, oxford_pet_data) # 创建路径如果不存在确保PyTorch有写入权限 os.makedirs(DATA_ROOT, exist_okTrue) print(f【全局配置】数据集保存路径{DATA_ROOT}) # 训练超参数新手友好注释说明调整原则 BATCH_SIZE 16 # 批次大小GPU内存小→改8/4内存大→改32/64 EPOCHS 5 # 训练轮数测试用5轮实际训练可改10/20轮 LEARNING_RATE 0.001 # 学习率越小训练越稳定越大收敛越快但易震荡 VAL_SPLIT 0.2 # 验证集比例从训练集中划分20%用于验证避免过拟合3.数据预处理def get_transforms(): 定义数据预处理规则 核心目的将原始图片转换为模型可识别的张量并通过数据增强提升泛化能力 分为训练集带增强和验证/测试集无增强两类规则 # 训练集预处理加入数据增强解决样本量不足提升模型鲁棒性 train_transform transforms.Compose([ # 调整尺寸统一缩放到256x256后续随机裁剪到224x224 transforms.Resize((256, 256)), # 随机裁剪从256x256中随机裁224x224模拟不同视角 transforms.RandomResizedCrop(224), # 随机水平翻转50%概率翻转增加数据多样性猫狗左右翻转不影响分类 transforms.RandomHorizontalFlip(p0.5), # 转换为张量将PIL图片/NumPy数组转为torch.Tensor像素值从0-255归一化到0-1 transforms.ToTensor(), # 标准化使用ImageNet均值/标准差适配预训练模型的输入分布 # 公式output (input - mean) / std transforms.Normalize( mean[0.485, 0.456, 0.406], # RGB三通道均值 std[0.229, 0.224, 0.225] # RGB三通道标准差 ) ]) # 验证/测试集预处理仅做基础变换保证评估的公平性 val_test_transform transforms.Compose([ transforms.Resize((224, 224)), # 直接缩放到224x224无随机 transforms.ToTensor(), # 仅转换张量 transforms.Normalize( # 同训练集的标准化必须一致 mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225] ) ]) return train_transform, val_test_transform4.加载内置数据集def load_oxford_pet_dataset(): 加载PyTorch内置的Oxford-IIIT Pet数据集 该数据集包含37类宠物25类狗12类猫共7349张图片是官方标配的猫狗分类数据集 关键参数downloadTrue → PyTorch自动下载并解压无需手动处理 # 获取预处理规则 train_transform, val_test_transform get_transforms() # 加载训练验证集splittrainval # root数据集保存路径 # split选择数据集划分trainval训练验证test测试 # transform应用的预处理规则 # downloadTrue自动下载首次运行下载后续跳过 full_train_dataset OxfordIIITPet( root DATA_ROOT, split trainval, transform train_transform, download True # 核心PyTorch自动处理下载/解压无需手动操作 ) # 加载测试集 test_dataset OxfordIIITPet( root DATA_ROOT, split test, transform val_test_transform, download True ) # 打印数据集基本信息帮助理解数据规模 print(\n【数据集信息】) print(f - 训练验证集总数{len(full_train_dataset)} 张) print(f - 测试集总数{len(test_dataset)} 张) print(f - 分类类别数{len(full_train_dataset.classes)} 类25类狗12类猫) print(f - 类别示例{full_train_dataset.classes[:5]}前5类) # 划分训练集和验证集从trainval中拆分 # 计算划分数量80%训练20%验证 train_size int((1 - VAL_SPLIT) * len(full_train_dataset)) val_size len(full_train_dataset) - train_size # random_split随机划分数据集保证随机性 train_dataset, val_dataset random_split(full_train_dataset, [train_size, val_size]) # 关键验证集替换为无增强的预处理避免增强影响验证结果 val_dataset.dataset.transform val_test_transform # 构建DataLoader批量加载数据是PyTorch训练的核心数据接口 # shuffleTrue训练集打乱顺序提升训练效果验证/测试集False # num_workers0Windows系统避免多线程报错Mac/Linux可改2/4 train_loader DataLoader( train_dataset, batch_size BATCH_SIZE, shuffle True, num_workers 0 ) val_loader DataLoader( val_dataset, batch_size BATCH_SIZE, shuffle False, num_workers 0 ) test_loader DataLoader( test_dataset, batch_size BATCH_SIZE, shuffle False, num_workers 0 ) print(\n【DataLoader信息】) print(f - 训练集批次{len(train_loader)} 批每批{BATCH_SIZE}张) print(f - 验证集批次{len(val_loader)} 批) print(f - 测试集批次{len(test_loader)} 批) return train_loader, val_loader, test_loader, full_train_dataset.classes5.构建迁移学习模型def build_model(num_classes): 构建基于ResNet18的迁移学习模型 核心思路复用预训练模型的特征提取能力仅训练最后一层分类头高效、快速收敛 ResNet18是轻量级经典模型适合猫狗分类这类简单任务 print(\n【模型构建】加载预训练ResNet18模型...) # 加载预训练的ResNet18pretrainedTrue → 使用ImageNet预训练权重 # 预训练权重包含大量视觉特征无需从零训练大幅提升效果 model models.resnet18(pretrainedTrue) # 冻结特征提取层所有卷积层 # requires_gradFalse → 反向传播时不计算梯度不更新参数 # 目的只训练最后一层分类头节省计算资源避免过拟合 for param in model.parameters(): param.requires_grad False # 替换最后一层全连接层适配Oxford Pet的37类分类 # model.fc是ResNet18的最后一层默认输出1000类ImageNet # 步骤1获取最后一层的输入特征数 in_features model.fc.in_features # 步骤2替换为新的全连接层输出类别数37 model.fc nn.Linear(in_features, num_classes) # 将模型移到指定设备GPU/CPU所有后续计算都在该设备上 model model.to(DEVICE) print(f✅ 模型构建完成) print(f - 特征提取层冻结ResNet18预训练) print(f - 分类头{in_features} → {num_classes} 类) return model6.训练/验证函数def train_one_epoch(model, train_loader, criterion, optimizer, epoch): 训练模型一个轮次Epoch 核心流程前向传播→计算损失→反向传播→更新参数 # 切换模型到训练模式启用Dropout、BatchNorm等训练相关层 model.train() total_loss 0.0 # 累计损失 correct 0 # 累计正确预测数 total 0 # 累计样本数 print(f\n【训练Epoch {epoch 1}/{EPOCHS}】) # 遍历训练集所有批次 for batch_idx, (images, labels) in enumerate(train_loader): # 步骤1将数据移到指定设备必须和模型同设备否则报错 images images.to(DEVICE) labels labels.to(DEVICE) # 步骤2前向传播模型预测 outputs model(images) # outputs.shape (BATCH_SIZE, 37) # 步骤3计算损失交叉熵损失多分类任务标配 # criterionnn.CrossEntropyLoss()自动计算softmax负对数似然 loss criterion(outputs, labels) # 步骤4反向传播参数更新核心三步 optimizer.zero_grad() # 清空上一批次的梯度必须否则梯度累加 loss.backward() # 反向传播计算各参数的梯度 optimizer.step() # 优化器更新参数根据梯度调整权重 # 步骤5统计训练指标 total_loss loss.item() * images.size(0) # 累计损失乘以批次大小 # torch.max获取预测类别取outputs中最大值的索引 _, predicted torch.max(outputs, 1) total labels.size(0) # 累计样本数 correct (predicted labels).sum().item() # 累计正确数 # 每20批打印一次进度避免刷屏 if (batch_idx 1) % 20 0: batch_acc 100 * correct / total print(f 批次 {batch_idx1}/{len(train_loader)} | 损失{loss.item():.4f} | 准确率{batch_acc:.2f}%) # 计算本轮平均损失和准确率 avg_loss total_loss / len(train_loader.dataset) avg_acc 100 * correct / total print(f【训练结果】Epoch {epoch1} | 平均损失{avg_loss:.4f} | 平均准确率{avg_acc:.2f}%) return avg_loss, avg_acc def validate(model, val_loader, criterion,epoch): 验证模型性能 核心区别关闭梯度计算torch.no_grad()避免占用内存不更新参数 # 切换模型到验证模式关闭Dropout、BatchNorm等训练层 model.eval() total_loss 0.0 correct 0 total 0 print(f\n【验证Epoch {epoch1}/{EPOCHS}】) # torch.no_grad()禁用梯度计算大幅提升验证速度节省内存 with torch.no_grad(): for images, labels in val_loader: images images.to(DEVICE) labels labels.to(DEVICE) # 仅前向传播无反向传播 outputs model(images) loss criterion(outputs, labels) # 统计指标同训练 total_loss loss.item() * images.size(0) _, predicted torch.max(outputs, 1) total labels.size(0) correct (predicted labels).sum().item() # 计算验证指标 avg_loss total_loss / len(val_loader.dataset) avg_acc 100 * correct / total print(f【验证结果】Epoch {epoch1} | 平均损失{avg_loss:.4f} | 平均准确率{avg_acc:.2f}%) return avg_loss, avg_acc7.主流程def main(): 主流程加载数据集→构建模型→训练→验证→测试→可视化 全程基于PyTorch内置数据集无需手动下载/解压 # 步骤1加载内置数据集自动下载 train_loader, val_loader, test_loader, classes load_oxford_pet_dataset() num_classes len(classes) # 步骤2构建迁移学习模型 model build_model(num_classes) # 步骤3定义损失函数和优化器 # 损失函数CrossEntropyLoss → 多分类任务首选 criterion nn.CrossEntropyLoss() # 优化器Adam → 收敛快适合迁移学习仅优化分类头参数 optimizer optim.Adam(model.parameters(), lrLEARNING_RATE) # 步骤4训练过程记录 train_losses [] # 训练损失记录 val_losses [] # 验证损失记录 train_accs [] # 训练准确率记录 val_accs [] # 验证准确率记录 best_val_acc 0.0 # 保存最佳验证准确率 # 步骤5迭代训练 print(\n *80) print(【开始训练】全程使用PyTorch内置Oxford-IIIT Pet数据集) print(*80) for epoch in range(EPOCHS): # 训练一轮 train_loss, train_acc train_one_epoch(model, train_loader, criterion, optimizer, epoch) # 验证一轮 val_loss, val_acc validate(model, val_loader, criterion,epoch) # 记录指标 train_losses.append(train_loss) val_losses.append(val_loss) train_accs.append(train_acc) val_accs.append(val_acc) # 保存最佳模型验证准确率更高时 if val_acc best_val_acc: best_val_acc val_acc save_path os.path.join(DATA_ROOT, best_cat_dog_model.pth) torch.save(model.state_dict(), save_path) print(f 保存最佳模型验证准确率 {best_val_acc:.2f}% → {save_path}) # 步骤6测试最佳模型 print(\n *80) print(【测试最佳模型】) print(*80) # 加载最佳模型权重 model.load_state_dict(torch.load(save_path)) model.eval() test_correct 0 test_total 0 with torch.no_grad(): for images, labels in test_loader: images images.to(DEVICE) labels labels.to(DEVICE) outputs model(images) _, predicted torch.max(outputs, 1) test_total labels.size(0) test_correct (predicted labels).sum().item() test_acc 100 * test_correct / test_total print(f✅ 测试完成 | 测试集准确率{test_acc:.2f}%) # 步骤7可视化训练曲线中文显示 plt.rcParams[font.sans-serif] [SimHei] # 解决中文乱码 plt.figure(figsize(12, 5)) # 损失曲线 plt.subplot(1, 2, 1) plt.plot(train_losses, label训练损失, markero) plt.plot(val_losses, label验证损失, markers) plt.title(训练/验证损失变化) plt.xlabel(Epoch) plt.ylabel(Loss) plt.legend() plt.grid(alpha0.3) # 准确率曲线 plt.subplot(1, 2, 2) plt.plot(train_accs, label训练准确率, markero) plt.plot(val_accs, label验证准确率, markers) plt.title(训练/验证准确率变化) plt.xlabel(Epoch) plt.ylabel(准确率(%)) plt.legend() plt.grid(alpha0.3) plt.tight_layout() plt.show()8.运行主函数if __name__ __main__: # 首次运行会自动下载Oxford-IIIT Pet数据集约700MB后续运行跳过 main()4.运行结果展示1.数据集信息2.训练信息3.测试信息祝努力的你闪闪发光

相关新闻

千问进校园:当课堂遇上AI,每个孩子都有了“专属导师”

千问进校园:当课堂遇上AI,每个孩子都有了“专属导师”

引言 在科技快速发展的今天,人工智能已经逐渐渗透到我们生活的各个领域,教育行业也不例外。阿里巴巴推出的千问大模型,正在为教育领域带来一场前所未有的变革。当课堂遇上AI,每个孩子都有了属于自己的“专属导师”,个性…

2026/7/5 13:16:22 阅读更多 →
2026必学!AI大模型架构全解析:基础模型、微调与插件谁更重要?(收藏必备)

2026必学!AI大模型架构全解析:基础模型、微调与插件谁更重要?(收藏必备)

2025年,Llama 3 405B模型凭借15T tokens的预训练语料实现通用能力的跨越式提升,文心一言4.0通过知识增强架构在中文合规场景脱颖而出,ChatGPT-4o则依靠插件生态完成从文本交互到多任务处理的进化——这些主流大模型的成功,背后都离…

2026/5/17 9:16:50 阅读更多 →
多模态文档智能解析教程(非常详细),Youtu-Parsing模型从架构到训练,收藏这一篇就够了!

多模态文档智能解析教程(非常详细),Youtu-Parsing模型从架构到训练,收藏这一篇就够了!

优图开源一个多模态文档解析模型-Youtu-Parsing-2.5B,这是一个以多模态视觉语言模型为基础的pipeline结构(即:vlm既做layout版式分析又做ocr Format识别),并使用高并行性解码策略解决传统文档解析中 “自回归解码速度慢…

2026/5/17 9:16:49 阅读更多 →

最新新闻

零日漏洞攻防实战:从检测到响应的纵深防御体系构建

零日漏洞攻防实战:从检测到响应的纵深防御体系构建

1. 项目概述:直面数字世界的“隐形杀手”在网络安全这个没有硝烟的战场上,最让防御者感到棘手的,往往不是那些已知的、有补丁可循的威胁,而是那些被称为“零日漏洞”的未知攻击。从业十几年,我处理过无数次安全事件&am…

2026/7/5 13:16:07 阅读更多 →
多人聊天室

多人聊天室

一、项目简介本项目是一个基于Java Swing MySQL的博客文章管理系统,实现了文章发布、分类管理、用户登录、全局搜索等核心功能。 我在项目中主要负责全局搜索模块、数据库读写层设计以及部分面向对象架构设计工作。二、个人任务简述序号完成功能与任务描述1全局搜索…

2026/7/5 13:14:06 阅读更多 →
骑乘无忧怎么选 (新手女生小个子巡航摩托)选购要点

骑乘无忧怎么选 (新手女生小个子巡航摩托)选购要点

入手自动挡巡航摩托,CVT 和 AMT 该怎么选?面向入门骑手、女性车友以及身高娇小的人群,最优方案已然明确。AMT 巡航操控顺手、动力充沛、使用便捷,外观也十分出彩,是综合实力更强的选择。QJMOTOR 闪 300AMT 与闪 400AMT…

2026/7/5 13:14:06 阅读更多 →
Azure Local离线模式采购(系列篇之七)

Azure Local离线模式采购(系列篇之七)

0. 重要定位(先看清 Acquire 在做什么) ⚠️ Acquire ≠ 部署完成。Acquire 阶段仅完成 Azure 资源创建及部署介质获取,Virtual Appliance 尚未部署到本地数据中心。完整的生命周期是: Acquire → Deploy → Configure → Operate…

2026/7/5 13:12:06 阅读更多 →
杭州老板IP打造运营公司怎么选?

杭州老板IP打造运营公司怎么选?

选择杭州的老板IP打造运营公司时,可以从以下几个方面进行考量:一、明确需求与目标核心需求:首先明确你希望通过IP打造实现什么目的。是增加品牌知名度、提升客户信任度,还是直接促进销售转化? 行业特性:根据…

2026/7/5 13:12:06 阅读更多 →
input_report_key + input_sync:按键事件的正确报告姿势

input_report_key + input_sync:按键事件的正确报告姿势

input_report_key input_sync:按键事件的正确报告姿势这个仓库已经开源!所有教程,主线内核移植,跑新版本imx-linux/uboot都在这里,或者一起来尝试跑7.1的Linux!欢迎各位大佬观摩!喜欢的话点个⭐…

2026/7/5 13:10:06 阅读更多 →

日新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

周新闻

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容

B站视频下载神器BiliTools:5分钟学会轻松保存任何B站内容 【免费下载链接】BiliTools A cross-platform bilibili toolbox. 跨平台哔哩哔哩工具箱,支持下载视频、番剧等等各类资源 项目地址: https://gitcode.com/GitHub_Trending/bilit/BiliTools …

2026/7/5 0:03:34 阅读更多 →
威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型全解析:从新手入门到实战应用,助你构建安全产品!

威胁模型的陌生现状在忙碌疲惫的一天里,参与了关于混合后量子密码学的讨论,应付端点攻击找茬的人,还参与留言板讨论后,发现“威胁模型”对多数人仍是陌生概念,且多被当作时髦用语。有趣的相关画作有一幅由 Embyr 创作的…

2026/7/5 0:03:34 阅读更多 →
渗透测试入门指南:从零基础到实战环境搭建

渗透测试入门指南:从零基础到实战环境搭建

1. 从“看热闹”到“入门”:我理解的渗透测试到底是什么?每次看到新闻里说某个大公司的数据被“黑”了,或者某个网站被攻击导致服务瘫痪,你是不是和我一样,心里会冒出两个念头:一是“这黑客真厉害”&#x…

2026/7/5 0:07:38 阅读更多 →

月新闻