训练
train_model(experiment_name, experiment_dir, config, data_module, n_steps_train, metric_tracker=None, tokenizer=get_tokenizer(), peft=None, _use_rich_model_summary=True)
使用 PyTorch Lightning 训练 BioNeMo ESM2 模型。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
experiment_name
|
str
|
实验的名称。 |
必需 |
experiment_dir
|
Path
|
实验将保存的目录。 |
必需 |
config
|
ESM2GenericConfig
|
ESM2 模型的配置。 |
必需 |
data_module
|
LightningDataModule
|
用于训练和验证的数据模块。 |
必需 |
n_steps_train
|
int
|
训练步骤的数量。 |
必需 |
metric_tracker
|
Callback | None
|
用于跟踪指标的可选回调 |
None
|
tokenizer
|
BioNeMoESMTokenizer
|
要使用的分词器。默认为 |
get_tokenizer()
|
peft
|
PEFT | None
|
PEFT(参数高效微调)模块。默认为 None。 |
None
|
_use_rich_model_summary
|
bool
|
是否使用 RichModelSummary 回调,在 https://nvbugspro.nvidia.com/bug/4959776 解决之前,我们的测试套件中省略了此回调。默认为 True。 |
True
|
返回值
类型 | 描述 |
---|---|
Path
|
一个元组,包含保存的检查点路径、一个 MetricTracker |
Callback | None
|
对象和 PyTorch Lightning Trainer 对象。 |
源代码位于 bionemo/esm2/model/finetune/train.py
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
|