PyTorch 中的 ImageNet 训练#

这实现了在 ImageNet 数据集上训练流行的模型架构,例如 ResNet、AlexNet 和 VGG。

此版本已修改为使用 DALI。它假设数据集是来自 ImageNet 数据集的原始 JPEG 图像。它为 DALI 提供基于 CPU 和 GPU 的流水线 - 使用 dali_cpu 开关启用 CPU 版本。对于重型 GPU 网络(如 RN50),基于 CPU 的版本更快,对于某些 CPU 成为瓶颈的轻型网络(如 RN18),GPU 版本更快。此版本已修改为使用 APEx 中的 DistributedDataParallel 模块,而不是上游 PyTorch 中的模块。请从 此处 安装 APEx。

要运行,请使用以下命令

ln -s /path/to/train/jpeg/ train
ln -s /path/to/validation/jpeg/ val
torchrun --nproc_per_node=NUM_GPUS main.py -a resnet50 --dali_cpu --b 128 \
         --loss-scale 128.0 --workers 4 --lr=0.4 --fp16-mode ./

要求#

  • APEx - 可选(对于 PyTorch 1.6,它是上游的一部分,因此无需单独安装),fp16 模式或分布式(多 GPU)操作需要

  • 从源代码安装 PyTorch,github 上的 PyTorch 的主分支

  • pip install -r requirements.txt

  • 下载 ImageNet 数据集并将验证图像移动到带标签的子文件夹

    • 为此,您可以使用以下 脚本

训练#

要训练模型,请使用所需的模型架构和 ImageNet 数据集的路径运行 docs/examples/use_cases/pytorch/resnet50/main.py

python main.py -a resnet18 [imagenet-folder with train and val folders]

默认学习率计划从 0.1 开始,每 30 个 epoch 衰减 10 倍。这适用于 ResNet 和带有批归一化的模型,但对于 AlexNet 和 VGG 来说太高了。对于 AlexNet 或 VGG,请使用 0.01 作为初始学习率

python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]

数据加载器#

  • dali:利用 DALI 流水线以及 DALI 的 PyTorch 迭代器进行数据加载、预处理和数据增强。

  • dali_proxy:使用 DALI 流水线进行预处理和数据增强,同时依赖 PyTorch 的数据加载器。DALI Proxy 方便将数据传输到 DALI 进行处理。请参阅 PyTorch DALI 代理

  • pytorch:采用原生 PyTorch 数据加载器进行数据预处理和数据增强。

用法#