ImageNet 训练在 PyTorch 中的应用#

这实现了在 ImageNet 数据集上训练流行的模型架构,例如 ResNet、AlexNet 和 VGG。

此版本已修改为使用 DALI。它假设数据集是来自 ImageNet 数据集的原始 JPEG。它为 DALI 提供基于 CPU 和 GPU 的 pipeline - 使用 dali_cpu 开关来启用 CPU pipeline。对于重型 GPU 网络(如 RN50),基于 CPU 的 pipeline 更快,对于一些 CPU 是瓶颈的轻型网络(如 RN18),GPU pipeline 更快。此版本已修改为使用 APEx 中的 DistributedDataParallel 模块,而不是上游 PyTorch 中的模块。请从此处安装 APEx。

要运行,请使用以下命令

ln -s /path/to/train/jpeg/ train
ln -s /path/to/validation/jpeg/ val
torchrun --nproc_per_node=NUM_GPUS main.py -a resnet50 --dali_cpu --b 128 \
         --loss-scale 128.0 --workers 4 --lr=0.4 --fp16-mode ./

要求#

  • APEx - 可选(从 PyTorch 1.6 开始,它是上游的一部分,因此无需单独安装),fp16 模式或分布式(多 GPU)操作需要

  • 从源码安装 PyTorch,github 上的 PyTorch 的 main 分支

  • pip install -r requirements.txt

  • 下载 ImageNet 数据集并将验证图像移动到带标签的子文件夹

    • 为此,您可以使用以下脚本

训练#

要训练模型,请使用所需的模型架构和 ImageNet 数据集的路径运行docs/examples/use_cases/pytorch/resnet50/main.py

python main.py -a resnet18 [imagenet-folder with train and val folders]

默认的学习率计划从 0.1 开始,每 30 个 epoch 衰减 10 倍。这适用于 ResNet 和具有批归一化的模型,但对于 AlexNet 和 VGG 来说太高。对于 AlexNet 或 VGG,请使用 0.01 作为初始学习率

python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]

用法#

main.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N] [--resume PATH] [-e] [--pretrained] [--opt-level] DIR

PyTorch ImageNet Training

positional arguments:
DIR                         path(s) to dataset (if one path is provided, it is assumed to have subdirectories named "train" and "val"; alternatively, train and val paths can be specified directly by providing both paths as arguments)

optional arguments (for the full list please check `Apex ImageNet example
         <https://github.com/NVIDIA/apex/tree/master/examples/imagenet>`_)
-h, --help                  show this help message and exit
--arch ARCH, -a ARCH        model architecture: alexnet | resnet | resnet101
                            | resnet152 | resnet18 | resnet34 | resnet50 | vgg
                            | vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16
                            | vgg16_bn | vgg19 | vgg19_bn (default: resnet18)
-j N, --workers N           number of data loading workers (default: 4)
--epochs N                  number of total epochs to run
--start-epoch N             manual epoch number (useful on restarts)
-b N, --batch-size N        mini-batch size (default: 256)
--lr LR, --learning-rate LR initial learning rate
--momentum M                momentum
--weight-decay W, --wd W    weight decay (default: 1e-4)
--print-freq N, -p N        print frequency (default: 10)
--resume PATH               path to latest checkpoint (default: none)
-e, --evaluate              evaluate model on validation set
--pretrained                use pre-trained model
--dali_cpu                  use CPU based pipeline for DALI, for heavy GPU
                            networks it may work better, for IO bottlenecked
                            one like RN18 GPU default should be faster
--disable_dali              turns off DALI and switches to the native PyTorch
                            data processing
--fp16-mode                 enables mixed precision mode