PyTorch 中的 ImageNet 训练#
这实现了在 ImageNet 数据集上训练流行的模型架构,例如 ResNet、AlexNet 和 VGG。
此版本已修改为使用 DALI。它假设数据集是来自 ImageNet 数据集的原始 JPEG 图像。它为 DALI 提供基于 CPU 和 GPU 的流水线 - 使用 dali_cpu 开关启用 CPU 版本。对于重型 GPU 网络(如 RN50),基于 CPU 的版本更快,对于某些 CPU 成为瓶颈的轻型网络(如 RN18),GPU 版本更快。此版本已修改为使用 APEx 中的 DistributedDataParallel 模块,而不是上游 PyTorch 中的模块。请从 此处 安装 APEx。
要运行,请使用以下命令
ln -s /path/to/train/jpeg/ train
ln -s /path/to/validation/jpeg/ val
torchrun --nproc_per_node=NUM_GPUS main.py -a resnet50 --dali_cpu --b 128 \
--loss-scale 128.0 --workers 4 --lr=0.4 --fp16-mode ./
要求#
APEx - 可选(对于 PyTorch 1.6,它是上游的一部分,因此无需单独安装),fp16 模式或分布式(多 GPU)操作需要
从源代码安装 PyTorch,github 上的 PyTorch 的主分支
pip install -r requirements.txt
下载 ImageNet 数据集并将验证图像移动到带标签的子文件夹
为此,您可以使用以下 脚本
训练#
要训练模型,请使用所需的模型架构和 ImageNet 数据集的路径运行 docs/examples/use_cases/pytorch/resnet50/main.py
python main.py -a resnet18 [imagenet-folder with train and val folders]
默认学习率计划从 0.1 开始,每 30 个 epoch 衰减 10 倍。这适用于 ResNet 和带有批归一化的模型,但对于 AlexNet 和 VGG 来说太高了。对于 AlexNet 或 VGG,请使用 0.01 作为初始学习率
python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]
数据加载器#
dali:利用 DALI 流水线以及 DALI 的 PyTorch 迭代器进行数据加载、预处理和数据增强。
dali_proxy:使用 DALI 流水线进行预处理和数据增强,同时依赖 PyTorch 的数据加载器。DALI Proxy 方便将数据传输到 DALI 进行处理。请参阅 PyTorch DALI 代理。
pytorch:采用原生 PyTorch 数据加载器进行数据预处理和数据增强。