Callbacks
PredictionWriter
基类:BasePredictionWriter
, Callback
一个回调,在训练期间以指定的间隔将预测写入磁盘。
源代码位于 bionemo/llm/utils/callbacks.py
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
|
__init__(output_dir, write_interval)
初始化回调。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
output_dir
|
str | PathLike
|
将写入预测的目录。 |
必需 |
write_interval
|
IntervalT
|
写入预测的间隔。(批次,epoch) |
必需 |
源代码位于 bionemo/llm/utils/callbacks.py
34 35 36 37 38 39 40 41 42 43 |
|
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
在每个批次结束时将预测写入磁盘。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
trainer
|
Trainer
|
Trainer 实例。 |
必需 |
pl_module
|
LightningModule
|
LightningModule 实例。 |
必需 |
prediction
|
Any
|
模型做出的预测。 |
必需 |
batch_indices
|
Sequence[int]
|
批次的索引。 |
必需 |
batch
|
Any
|
批次数据。 |
必需 |
batch_idx
|
int
|
批次的索引。 |
必需 |
dataloader_idx
|
int
|
数据加载器的索引。 |
必需 |
源代码位于 bionemo/llm/utils/callbacks.py
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
|
write_on_epoch_end(trainer, pl_module, predictions, batch_indices)
在每个 epoch 结束时将预测写入磁盘。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
trainer
|
Trainer
|
Trainer 实例。 |
必需 |
pl_module
|
LightningModule
|
LightningModule 实例。 |
必需 |
predictions
|
Any
|
模型做出的预测。 |
必需 |
batch_indices
|
Sequence[int]
|
批次的索引。 |
必需 |
源代码位于 bionemo/llm/utils/callbacks.py
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
|