# train_classify **Repository Path**: damone/train_classify ## Basic Information - **Project Name**: train_classify - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: main - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-10-12 - **Last Updated**: 2025-10-12 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 深度学习训练项目文档 ## 项目概述 本项目实现了使用多种深度学习模型在不同数据集(MNIST手写数字识别和Flowers花朵分类)上进行训练的完整流程。项目支持ResNet、 TinyViT 三种模型架构,提供了训练指标的实时可视化和保存功能。 TinyViT是一种新型的高效视觉Transformer,通过在大规模数据集上使用快速蒸馏框架进行预训练。核心思想是将知识从大型预训练模型转移到小型模型。大型教师模型的logits会被稀疏化并提前存储在磁盘中,以节省内存成本和计算开销。 ## 项目结构 ``` train-dl/ ├── data/ # 数据集目录 │ ├── MNIST/ # MNIST数据集 │ └── flower_photos/ # Flowers数据集 ├── resnet18_mnist/ # MNIST训练结果(ResNet-18) │ ├── resnet18_mnist.pth # 训练好的模型权重 │ ├── training_metrics.npz # 训练指标数据 │ ├── training_metrics.png # 训练指标图表 │ └── plot_metrics.py # 指标可视化脚本 ├── resnet50_mnist/ # MNIST训练结果(ResNet-50) │ ├── resnet50_mnist.pth # 训练好的模型权重 │ ├── training_metrics.npz # 训练指标数据 │ └── training_metrics.png # 训练指标图表 ├── resnet18_flowers/ # Flowers训练结果(ResNet-18) │ ├── resnet18_flowers.pth # 训练好的模型权重 │ ├── training_metrics_flowers.npz# 训练指标数据 │ └── training_metrics_flowers.png# 训练指标图表 ├── resnet50_flowers/ # Flowers训练结果(ResNet-50) │ ├── resnet50_flowers.pth # 训练好的模型权重 │ ├── training_metrics_flowers.npz# 训练指标数据 │ └── training_metrics_flowers.png# 训练指标图表 ├── vit_mnist/ # MNIST训练结果(ViT) │ ├── vit_mnist.pth # ViT训练好的模型权重 │ ├── training_metrics.npz # ViT训练指标数据 │ └── training_metrics.png # ViT训练指标图表 ├── vit_flowers/ # Flowers训练结果(ViT) │ ├── vit_flowers.pth # ViT训练好的模型权重 │ ├── training_metrics.npz # ViT训练指标数据 │ └── training_metrics.png # ViT训练指标图表 ├── TinyViT/ # TinyViT模型实现及相关文件 │ ├── models/ # TinyViT模型定义 │ ├── data/ # 数据处理相关代码 │ ├── configs/ # 配置文件 │ └── ... # 其他TinyViT相关文件 ├── resnet.py # ResNet模型实现 ├── vit.py # Vision Transformer模型实现 ├── train_mnist_resnet18.py # MNIST数据集训练脚本(ResNet-18) ├── train_mnist_resnet50.py # MNIST数据集训练脚本(ResNet-50) ├── train_flowers_resnet18.py # Flowers数据集训练脚本(ResNet-18) ├── train_flowers_resnet50.py # Flowers数据集训练脚本(ResNet-50) ├── train_mnist_vit.py # MNIST数据集训练脚本(ViT) ├── train_flowers_vit.py # Flowers数据集训练脚本(ViT) ├── print_resnet50.py # 打印ResNet50模型结构 ├── requirements.txt # 项目依赖 └── README.md # 项目文档 ``` ## 模型实现 ### ResNet模型 [resnet.py](file:///home/damone/project/train-dl/resnet.py)文件中包含了ResNet模型实现: 1. `Residual` - 基本残差块,expansion为1 2. `Bottleneck` - 瓶颈块,expansion为4 3. `resnet18()` - 适用于MNIST数据集(单通道图像) 4. `resnet18_rgb()` - 适用于CIFAR-10数据集(三通道图像) 5. `resnet50()` - 适用于MNIST数据集(单通道图像) 6. `resnet50_rgb()` - 适用于Flowers数据集(三通道图像) ### TinyViT [TinyViT](file:///home/damone/project/train-dl/TinyViT)目录中包含了TinyViT模型的完整实现: 1. 高效的小型视觉Transformer模型 2. 支持大规模数据集上的快速预训练蒸馏 3. TinyViT-21M仅使用21M参数就在ImageNet-1k上达到84.8%的top-1准确率 4. 支持更高分辨率(384x384和512x512)的微调版本 ## 训练脚本 ### MNIST数据集训练 1. [train_mnist_resnet18.py](file:///home/damone/project/train-dl/train_mnist_resnet18.py) - 使用ResNet-18模型训练MNIST数据集 2. [train_mnist_resnet50.py](file:///home/damone/project/train-dl/train_mnist_resnet50.py) - 使用ResNet-50模型训练MNIST数据集 3. [train_mnist_vit.py](file:///home/damone/project/train-dl/train_mnist_vit.py) - 使用ViT模型训练MNIST数据集 ### Flowers数据集训练 1. [train_flowers_resnet18.py](file:///home/damone/project/train-dl/train_flowers_resnet18.py) - 使用ResNet-18模型训练Flowers数据集 2. [train_flowers_resnet50.py](file:///home/damone/project/train-dl/train_flowers_resnet50.py) - 使用ResNet-50模型训练Flowers数据集 3. [train_flowers_vit.py](file:///home/damone/project/train-dl/train_flowers_vit.py) - 使用ViT模型训练Flowers数据集 ### 模型结构查看 1. [print_resnet50.py](file:///home/damone/project/train-dl/print_resnet50.py) - 打印ResNet50模型结构 ## 训练过程可视化 所有训练脚本都会记录以下关键指标: - 训练损失(Train Loss) - 训练准确率(Train Accuracy) - 测试准确率(Test Accuracy) 这些指标会被实时绘制在同一张图表上,并保存为PNG图像和NPZ数据文件。 ## 使用方法 ### 环境准备 确保已安装以下依赖: ``` torch torchvision matplotlib numpy ``` ### 运行训练 1. 训练MNIST数据集(ResNet-18): ``` python train_mnist_resnet18.py ``` 2. 训练MNIST数据集(ResNet-50): ``` python train_mnist_resnet50.py ``` 3. 训练MNIST数据集(ViT): ``` python train_mnist_vit.py ``` 4. 训练Flowers数据集(ResNet-18): ``` python train_flowers_resnet18.py ``` 5. 训练Flowers数据集(ResNet-50): ``` python train_flowers_resnet50.py ``` 6. 训练Flowers数据集(ViT): ``` python train_flowers_vit.py ``` ### 查看模型结构 ``` # 查看ResNet50模型结构 python print_resnet50.py # 查看ViT模型结构 python print_vit.py # 查看TinyViT模型结构 cd TinyViT python -c "from models import tiny_vit; model = tiny_vit.tiny_vit_21m_224(); print(model)" ``` ### 查看训练结果 训练完成后,结果将保存在对应的目录中: - MNIST ResNet-18结果保存在[resnet18_mnist/](file:///home/damone/project/train-dl/resnet18_mnist/)目录 - MNIST ResNet-50结果保存在[resnet50_mnist/](file:///home/damone/project/train-dl/resnet50_mnist/)目录 - Flowers ResNet-18结果保存在[resnet18_flowers/](file:///home/damone/project/train-dl/resnet18_flowers/)目录 - Flowers ResNet-50结果保存在[resnet50_flowers/](file:///home/damone/project/train-dl/resnet50_flowers/)目录 - MNIST ViT结果保存在[vit_mnist/](file:///home/damone/project/train-dl/vit_mnist/)目录 - Flowers ViT结果保存在[vit_flowers/](file:///home/damone/project/train-dl/vit_flowers/)目录 ## 数据集说明 ### MNIST 标准的手写数字识别数据集,包含0-9共10个数字类别。 ### Flowers 花朵图像数据集,包含5个类别: - Daisy(雏菊) - Dandelion(蒲公英) - Roses(玫瑰) - Sunflowers(向日葵) - Tulips(郁金香) ## 模型保存与加载 训练完成后,模型权重将保存为`.pth`文件。可以使用以下代码加载模型: ```python import torch import resnet # 加载MNIST模型 (ResNet-18) net = resnet.resnet18(num_classes=10) net.load_state_dict(torch.load('resnet18_mnist/resnet18_mnist.pth')) # 加载MNIST模型 (ResNet-50) net = resnet.resnet50(num_classes=10) net.load_state_dict(torch.load('resnet50_mnist/resnet50_mnist.pth')) # 加载Flowers模型 (ResNet-18) net = resnet.resnet18_cifar10(num_classes=5) # 注意:这里使用resnet18_cifar10 net.load_state_dict(torch.load('resnet18_flowers/resnet18_flowers.pth')) # 加载Flowers模型 (ResNet-50) net = resnet.resnet50_rgb(num_classes=5) net.load_state_dict(torch.load('resnet50_flowers/resnet50_flowers.pth')) ``` TinyViT模型加载示例: ```python # 加载TinyViT模型 import sys sys.path.append('TinyViT') from models import tiny_vit # 创建TinyViT-21M模型 model = tiny_vit.tiny_vit_21m_224(pretrained=True) output = model(image) ``` ## 训练指标可视化 训练指标保存为NPZ格式,可以直接查看PNG图表文件,也可以使用可视化脚本重新生成图表: ```bash cd resnet18_mnist python plot_metrics.py ``` 或者手动加载并绘制: ```python import numpy as np import matplotlib.pyplot as plt # 加载训练指标 data = np.load('resnet18_mnist/training_metrics.npz') train_losses = data['train_losses'] train_accs = data['train_accs'] test_accs = data['test_accs'] # 绘制指标 epochs = range(1, len(train_losses) + 1) plt.figure(figsize=(10, 6)) plt.plot(epochs, train_losses, label='Train Loss') plt.plot(epochs, train_accs, label='Train Accuracy') plt.plot(epochs, test_accs, label='Test Accuracy') plt.legend() plt.grid(True) plt.show() ``` ## ViT 训练优化 ViT 模型训练相比传统 CNN 模型更加不稳定,为了解决这个问题,我们引入了学习率调节机制: 1. 使用 `ReduceLROnPlateau` 调度器,当训练损失停止减少时自动降低学习率 2. 根据验证集性能动态调整学习率,提高训练稳定性 3. 针对不同数据集调整 patience 参数,以适应不同数据集的学习特性 4. 在训练过程中监控梯度,防止梯度爆炸 这些优化措施显著提高了 ViT 模型在两个数据集上的训练稳定性。 ## 使用方法 ### 运行训练 1. 训练MNIST数据集(ResNet-18): ``` python train.py ``` 2. 训练MNIST数据集(ResNet-50): ``` python train_mnist_resnet50.py ``` 3. 训练MNIST数据集(ViT): ``` python train_mnist_vit.py ``` 4. 训练Flowers数据集(ResNet-18): ``` python train_flowers.py ``` 5. 训练Flowers数据集(ResNet-50): ``` python train_flowers_resnet50.py ``` 6. 训练Flowers数据集(ViT): ``` python train_flowers_vit.py ``` ### 查看模型结构 ``` # 查看ResNet50模型结构 python print_resnet50.py ``` ## 注意事项 1. 由于深度学习模型较大,建议在GPU上运行以提高训练速度 2. 对于内存有限的设备,可以适当减小batch size或图像尺寸 3. 训练时间较长,建议在后台运行或使用GPU加速 4. 如果遇到CUDA内存不足错误,可以尝试减小batch size或使用较小的图像尺寸 5. TinyViT模型训练需要大规模数据集(如ImageNet-1k或ImageNet-22k)以获得最佳性能 6. TinyViT支持分布式训练,可以使用多GPU加速训练过程