
Callbacks

PredictionWriter

Bases: BasePredictionWriter, Callback

A callback that writes predictions to disk at specified intervals during training.

Source code in bionemo/llm/utils/callbacks.py
class PredictionWriter(BasePredictionWriter, pl.Callback):
    """A callback that writes predictions to disk at specified intervals during training."""

    def __init__(self, output_dir: str | os.PathLike, write_interval: IntervalT):
        """Initializes the callback.

        Args:
            output_dir: The directory where predictions will be written.
            write_interval: The interval at which predictions will be written. (batch, epoch)

        """
        super().__init__(write_interval)
        self.output_dir = str(output_dir)

    def write_on_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        prediction: Any,
        batch_indices: Sequence[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Writes predictions to disk at the end of each batch.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            prediction: The prediction made by the model.
            batch_indices: The indices of the batch.
            batch: The batch data.
            batch_idx: The index of the batch.
            dataloader_idx: The index of the dataloader.
        """
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input
        torch.save(prediction, result_path)
        logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")

    def write_on_epoch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        predictions: Any,
        batch_indices: Sequence[int],
    ) -> None:
        """Writes predictions to disk at the end of each epoch.

        Args:
            trainer: The Trainer instance.
            pl_module: The LightningModule instance.
            predictions: The predictions made by the model.
            batch_indices: The indices of the batch.
        """
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of its respective rank
        result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

        # collate multiple batches / ignore empty ones
        prediction = batch_collator([item for item in predictions if item is not None])

        # batch_indices is not captured due to a lightning bug when return_predictions = False
        # we use input IDs in the prediction to map the result to input
        torch.save(prediction, result_path)
        logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
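
A minimal usage sketch: attach the callback to a Lightning Trainer and run prediction with return_predictions=False so results are streamed to disk rather than accumulated in memory. The toy model, data, and paths below are placeholders, and pl is assumed to be pytorch_lightning (or lightning.pytorch, depending on the installed stack), matching the alias used in the source above.

import os

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from bionemo.llm.utils.callbacks import PredictionWriter


class ToyModel(pl.LightningModule):
    """Stand-in model whose predict_step returns a dict, matching the dict-style predictions logged above."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx):
        (x,) = batch
        return {"input": x, "logits": self.linear(x)}


output_dir = "results/predictions"
os.makedirs(output_dir, exist_ok=True)  # the callback does not create the directory itself

loader = DataLoader(TensorDataset(torch.randn(32, 4)), batch_size=8)
writer = PredictionWriter(output_dir=output_dir, write_interval="epoch")
trainer = pl.Trainer(accelerator="cpu", devices=1, callbacks=[writer], logger=False)

# return_predictions=False keeps results out of host memory; the callback writes them
# to disk instead, one predictions__rank_<R>.pt file per rank.
trainer.predict(ToyModel(), dataloaders=loader, return_predictions=False)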

__init__(output_dir, write_interval)

Initializes the callback.

Parameters:

Name            Type            Description                                                               Default
output_dir      str | PathLike  The directory where predictions will be written.                         required
write_interval  IntervalT       The interval at which predictions will be written ("batch" or "epoch").  required
Source code in bionemo/llm/utils/callbacks.py
def __init__(self, output_dir: str | os.PathLike, write_interval: IntervalT):
    """Initializes the callback.

    Args:
        output_dir: The directory where predictions will be written.
        write_interval: The interval at which predictions will be written. (batch, epoch)

    """
    super().__init__(write_interval)
    self.output_dir = str(output_dir)
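
A brief construction sketch: output_dir may be a str or a pathlib.Path (it is normalized with str() above), and the interval strings "batch" and "epoch" follow the docstring and are assumed to be valid IntervalT values.

from pathlib import Path

from bionemo.llm.utils.callbacks import PredictionWriter

# One file per rank per batch (predictions__rank_<R>__batch_<B>.pt).
per_batch_writer = PredictionWriter(output_dir="results/per_batch", write_interval="batch")

# One collated file per rank at the end of the epoch (predictions__rank_<R>.pt).
per_epoch_writer = PredictionWriter(output_dir=Path("results") / "per_epoch", write_interval="epoch")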

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Writes predictions to disk at the end of each batch.

Parameters:

Name            Type             Description                        Default
trainer         Trainer          The Trainer instance.              required
pl_module       LightningModule  The LightningModule instance.      required
prediction      Any              The prediction made by the model.  required
batch_indices   Sequence[int]    The indices of the batch.          required
batch           Any              The batch data.                    required
batch_idx       int              The index of the batch.            required
dataloader_idx  int              The index of the dataloader.       required
Source code in bionemo/llm/utils/callbacks.py
def write_on_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    prediction: Any,
    batch_indices: Sequence[int],
    batch: Any,
    batch_idx: int,
    dataloader_idx: int,
) -> None:
    """Writes predictions to disk at the end of each batch.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        prediction: The prediction made by the model.
        batch_indices: The indices of the batch.
        batch: The batch data.
        batch_idx: The index of the batch.
        dataloader_idx: The index of the dataloader.
    """
    # this will create N (num processes) files in `output_dir` each containing
    # the predictions of its respective rank
    result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}__batch_{batch_idx}.pt")

    # batch_indices is not captured due to a lightning bug when return_predictions = False
    # we use input IDs in the prediction to map the result to input
    torch.save(prediction, result_path)
    logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
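
Reading the per-batch output back is a matter of globbing the output directory and calling torch.load. A sketch, assuming the hypothetical directory results/per_batch was passed to the callback and that predict_step returned a dict of tensors:

import glob
import os

import torch

output_dir = "results/per_batch"  # hypothetical: the directory passed to the callback

# write_on_batch_end produces one file per rank per batch:
#   predictions__rank_<R>__batch_<B>.pt
for path in sorted(glob.glob(os.path.join(output_dir, "predictions__rank_*__batch_*.pt"))):
    prediction = torch.load(path, map_location="cpu")
    # Each file holds whatever predict_step returned for that batch, e.g. a dict of tensors.
    print(path, list(prediction.keys()))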

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Writes predictions to disk at the end of each epoch.

Parameters:

Name           Type             Description                         Default
trainer        Trainer          The Trainer instance.               required
pl_module      LightningModule  The LightningModule instance.       required
predictions    Any              The predictions made by the model.  required
batch_indices  Sequence[int]    The indices of the batch.           required
Source code in bionemo/llm/utils/callbacks.py
def write_on_epoch_end(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    predictions: Any,
    batch_indices: Sequence[int],
) -> None:
    """Writes predictions to disk at the end of each epoch.

    Args:
        trainer: The Trainer instance.
        pl_module: The LightningModule instance.
        predictions: The predictions made by the model.
        batch_indices: The indices of the batch.
    """
    # this will create N (num processes) files in `output_dir` each containing
    # the predictions of its respective rank
    result_path = os.path.join(self.output_dir, f"predictions__rank_{trainer.global_rank}.pt")

    # collate multiple batches / ignore empty ones
    prediction = batch_collator([item for item in predictions if item is not None])

    # batch_indices is not captured due to a lightning bug when return_predictions = False
    # we use input IDs in the prediction to map the result to input
    torch.save(prediction, result_path)
    logging.info(f"Inference predictions are stored in {result_path}\n{prediction.keys()}")
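
To assemble a single view of an epoch's predictions across ranks, load each rank's file and merge. A sketch, assuming dict-of-tensor predictions with matching keys on every rank and a hypothetical results/per_epoch output directory:

import glob
import os

import torch

output_dir = "results/per_epoch"  # hypothetical: the directory passed to the callback

# write_on_epoch_end produces one collated file per rank: predictions__rank_<R>.pt
paths = sorted(glob.glob(os.path.join(output_dir, "predictions__rank_*.pt")))
per_rank = [torch.load(p, map_location="cpu") for p in paths]

# Concatenate the per-rank dicts into one view. Sample order follows the distributed
# sampler's sharding, so use IDs carried in the prediction dict to map rows back to
# the original inputs (batch_indices are not available; see the comment above).
merged = {key: torch.cat([d[key] for d in per_rank], dim=0) for key in per_rank[0]}
print({key: tuple(value.shape) for key, value in merged.items()})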