Lightning

BertBatch

Bases: BertBatchCore

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertBatch(BertBatchCore, total=False):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens: Tensor
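
As an illustration, a padded batch and a packed batch might be assembled as follows. This is a hedged sketch with toy shapes; the key names come from BertBatchCore (below) plus the optional cu_seqlens field above, and the trailing -1 padding of cu_seqlens mirrors what get_packed_seq_params (documented further down) expects from the collate_fn.

import torch

# Padded batch of 2 sequences, length 6 (BertBatchCore keys only).
padded_batch = {
    "text": torch.randint(0, 32, (2, 6)),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]),
}

# Packed batch (micro-batch size 1): the optional cu_seqlens key marks the
# cumulative sequence boundaries, here two sequences of lengths 4 and 6,
# padded with a trailing -1 by the (hypothetical) collate_fn.
packed_batch = {
    "text": torch.randint(0, 32, (1, 10)),
    "attention_mask": torch.ones(1, 10, dtype=torch.long),
    "cu_seqlens": torch.tensor([[0, 4, 10, -1]], dtype=torch.int32),
}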

BertBatchCore

Bases: TypedDict

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertBatchCore(TypedDict):
    """Input datatype for inference with BERT-like models."""

    text: Tensor
    attention_mask: Tensor

BertModel

Bases: Protocol[DataT]

Interface for BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class BertModel(Protocol[DataT]):
    """Interface for BERT-like models."""

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
    ) -> DataT:
        """Inference for BERT-like models.

        Inference for BERT-like models require their tokenized inputs by IDs, an attention mask over the input,
        and the original sequence lengths if the sequences are packed into a dense batch.
        """
        ...

forward(input_ids, attention_mask, packed_seq_params=None)

Inference for BERT-like models.

Inference for BERT-like models requires their tokenized inputs as IDs, an attention mask over the input, and the original sequence lengths if the sequences are packed into a dense batch.

Source code in bionemo/llm/model/biobert/lightning.py
def forward(
    self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
) -> DataT:
    """Inference for BERT-like models.

    Inference for BERT-like models require their tokenized inputs by IDs, an attention mask over the input,
    and the original sequence lengths if the sequences are packed into a dense batch.
    """
    ...
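
A minimal sketch of a class that structurally satisfies this protocol. The toy embedding-only model is hypothetical and only illustrates the required forward signature; real implementations are Megatron BERT variants.

from typing import Optional

import torch
from torch import Tensor


class ToyBert(torch.nn.Module):
    """Hypothetical stand-in satisfying the BertModel protocol (DataT = Tensor)."""

    def __init__(self, vocab_size: int = 32, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[object] = None
    ) -> Tensor:
        # Embed the token IDs and zero out padded positions via the attention mask.
        return self.embed(input_ids) * attention_mask.unsqueeze(-1)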

BioBertLightningModule

Bases: BionemoLightningModule

Source code in bionemo/llm/model/biobert/lightning.py
class BioBertLightningModule(BionemoLightningModule):
    def __init__(
        self,
        *args,
        data_step_function: DataStepFunction = biobert_data_step,
        forward_step_function: ForwardStepFunction = bert_forward_step,
        **kwargs,
    ):
        """DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
        This maps the old name `forward_step_function` to the new name `forward_step` and `data_step_function` to
        `data_step`.

        Args:
            *args: all args are passed through to BionemoLightningModule
            data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
            forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
            **kwargs: all other kwargs are passed through to BionemoLightningModule.
        """  # noqa: D205
        super().__init__(*args, forward_step=forward_step_function, data_step=data_step_function, **kwargs)

__init__(*args, data_step_function=biobert_data_step, forward_step_function=bert_forward_step, **kwargs)

DEPRECATED! Please use BionemoLightningModule instead; this class exists only so that older checkpoints can still be loaded. It maps the old name forward_step_function to the new name forward_step, and data_step_function to data_step.

Parameters

*args: All positional arguments are passed through to BionemoLightningModule. Default: ()
data_step_function (DataStepFunction): The data step function. Default: biobert_data_step
forward_step_function (ForwardStepFunction): The forward step function. Default: bert_forward_step
**kwargs: All other keyword arguments are passed through to BionemoLightningModule. Default: {}
Source code in bionemo/llm/model/biobert/lightning.py
def __init__(
    self,
    *args,
    data_step_function: DataStepFunction = biobert_data_step,
    forward_step_function: ForwardStepFunction = bert_forward_step,
    **kwargs,
):
    """DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
    This maps the old name `forward_step_function` to the new name `forward_step` and `data_step_function` to
    `data_step`.

    Args:
        *args: all args are passed through to BionemoLightningModule
        data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
        forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
        **kwargs: all other kwargs are passed through to BionemoLightningModule.
    """  # noqa: D205
    super().__init__(*args, forward_step=forward_step_function, data_step=data_step_function, **kwargs)
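
For reference, the mapping above means the two constructions below are interchangeable. This is a hedged sketch: config, my_forward_step, and my_data_step are placeholders from your own code, and any other required arguments (passed through *args/**kwargs) are assumed identical in both calls.

# Deprecated spelling, kept only so older checkpoints still load:
old_module = BioBertLightningModule(
    config=config,
    forward_step_function=my_forward_step,
    data_step_function=my_data_step,
)

# Equivalent current spelling:
new_module = BionemoLightningModule(
    config=config,
    forward_step=my_forward_step,
    data_step=my_data_step,
)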

SequenceBatch

Bases: SequenceBatchCore

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class SequenceBatch(SequenceBatchCore, total=False):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens_argmin: Tensor
    max_seqlen: Tensor

SequenceBatchCore

Bases: TypedDict

Input datatype for inference with BERT-like models.

Source code in bionemo/llm/model/biobert/lightning.py
class SequenceBatchCore(TypedDict):
    """Input datatype for inference with BERT-like models."""

    cu_seqlens: Tensor

bert_default_optimizer(model)

Returns the default optimizer for the BERT model.

Parameters

model (Module): The BERT model. Required.

Returns

FusedAdam: The default optimizer initialized for this BERT module's parameters, using a learning rate of 1e-4 and a weight decay of 1e-2.

Source code in bionemo/llm/model/biobert/lightning.py
def bert_default_optimizer(model: torch.nn.Module) -> FusedAdam:
    """Returns the default optimizer for the BERT model.

    Args:
        model: The BERT model.

    Returns:
        The default optimizer initialized for this BERT module's parameters.
        Uses a learning rate of 1e-4 and weight decay of 1e-2.
    """
    return FusedAdam(model.parameters(), lr=1e-4, weight_decay=0.01)
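
Typical usage is a one-liner. The sketch below assumes a CUDA-capable environment where Apex's FusedAdam (imported by this module) is available, and uses a plain Linear layer as a stand-in for a real BERT module.

import torch

model = torch.nn.Linear(8, 8)              # stands in for a BERT module
optimizer = bert_default_optimizer(model)  # FusedAdam with lr=1e-4, weight_decay=0.01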

bert_forward_step(model, batch)

Performs the model's forward pass using the batch, for Megatron compatibility.

This subsets the batch keys to those actually used by the model's forward pass and then calls the model's forward pass. If "cu_seqlens" is defined in the batch, the packed sequence parameters are also passed to the model for forward-pass efficiency.

Source code in bionemo/llm/model/biobert/lightning.py
def bert_forward_step(model: BertModel[DataT], batch: BertBatch) -> DataT:
    """Performs the model's forward pass using the batch, for Megatron compatibility.

    This subsets the batch keys to the ones actually used by forward pass of the model, and then calls the model's
    forward pass. If "cu_seqlens" are defined in the batch, then the packed sequence parameters are also passed to the
    model for forward pass efficiency.
    """
    if "cu_seqlens" in batch:
        forward_results = model.forward(
            input_ids=batch["text"],
            attention_mask=batch["attention_mask"],
            packed_seq_params=get_packed_seq_params(cast(SequenceBatch, batch)),
        )
    else:
        forward_results = model.forward(input_ids=batch["text"], attention_mask=batch["attention_mask"])
    # TODO support losses that also include the binary head, this means doing something more fancy than the one
    #      default GPT reduction function above MaskedTokenLossReduction()
    return forward_results
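
A hedged usage sketch. _Echo is a hypothetical stand-in exposing the BertModel forward signature; the packed variant assumes cu_seqlens was padded with -1 by the collate_fn, as get_packed_seq_params (below) expects.

import torch


class _Echo:
    def forward(self, input_ids, attention_mask, packed_seq_params=None):
        # A real model would return hidden states or logits; echoing the IDs
        # is enough to show which keys bert_forward_step routes to forward().
        return input_ids


ids = torch.randint(0, 32, (1, 10))
mask = torch.ones(1, 10, dtype=torch.long)

# Padded batch: only "text" and "attention_mask" are forwarded.
out = bert_forward_step(_Echo(), {"text": ids, "attention_mask": mask})

# Packed batch: "cu_seqlens" triggers get_packed_seq_params(), whose result is
# forwarded as packed_seq_params.
out_packed = bert_forward_step(
    _Echo(),
    {"text": ids, "attention_mask": mask, "cu_seqlens": torch.tensor([[0, 4, 10, -1]], dtype=torch.int32)},
)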

biobert_data_step(dataloader_iter)

Preprocesses a batch of data for the GeneFormer model and ingests a single batch of data from the dataloader iterator. Only the necessary batch keys are subsetted and passed to the model's forward pass and the loss forward pass, depending on the stage. TODO: document how parallel_state pipeline stages work.

Parameters

dataloader_iter: An iterator over the dataloader. Required.

Returns

output (Dict[str, Tensor]): A dictionary for this batch, limited to the relevant keys.

Source code in bionemo/llm/model/biobert/lightning.py
def biobert_data_step(dataloader_iter) -> Dict[str, Tensor]:
    """Preprocesses a batch of data for the GeneFormer model, and ingest a single batch of data from the dataloader iterator.
        only necessary batch keys are subsetted and passed to the model's forward pass, and the loss forward pass, depending on stage.
        TODO document how parallel_state pipeline stages work.

    Args:
        dataloader_iter: An iterator over the dataloader.

    Returns:
        output: A dictionary of this batch limiting to relevant keys.

    """  # noqa: D205
    # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842

    batch = next(dataloader_iter)

    if isinstance(batch, tuple) and len(batch) == 3:
        _batch: dict = batch[0]
    else:
        _batch = batch

    required_keys = set()
    required_keys.add("attention_mask")
    if parallel_state.is_pipeline_first_stage():
        required_keys.add("text")
    if parallel_state.is_pipeline_last_stage():
        required_keys.update(("labels", "loss_mask", "types", "is_random"))
    # if self.get_attention_mask_from_fusion:
    #     required_keys.remove('attention_mask')

    _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
    # slice batch along sequence dimension for context parallelism
    output = get_batch_on_this_context_parallel_rank(_batch)

    return output
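
For orientation, here is a sketch of the kind of batch dict the iterator is expected to yield. The key names are taken from the required_keys logic above; the shapes are illustrative, and actually calling biobert_data_step additionally assumes a GPU and an initialized Megatron parallel_state.

import torch

# Illustrative micro-batch; keys that the current pipeline stage does not need
# are replaced with None by biobert_data_step before the forward pass.
raw_batch = {
    "text": torch.randint(0, 32, (2, 6)),
    "attention_mask": torch.ones(2, 6, dtype=torch.long),
    "labels": torch.randint(0, 32, (2, 6)),
    "loss_mask": torch.ones(2, 6, dtype=torch.bool),
    "types": torch.zeros(2, 6, dtype=torch.long),
    "is_random": torch.zeros(2, dtype=torch.long),
}
# batch = biobert_data_step(iter([raw_batch]))  # moves the stage-relevant keys
# to the GPU, then slices the sequence dimension for context parallelism.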

biobert_lightning_module(config, optimizer=None, tokenizer=None, data_step=biobert_data_step, forward_step=bert_forward_step, model_transform=None, **model_construct_args)

A PyTorch Lightning module for BioBert-derived models.

This module is designed to be used with the Megatron-LM strategy and NeMo 2.0 conventions. To change your loss, pass in a different config object that returns a different loss reduction class. To change your model and what it outputs, pass in a different config object that returns a different model. Do not modify this function unless you need to change higher-level logic. You may need to modify the various step and forward functions toward the bottom of this file to handle new or different keys in the batch. In the future, some of those functions may need to be refactored into the config object or another location so that they live closer to the model definition.

Source code in bionemo/llm/model/biobert/lightning.py
def biobert_lightning_module(
    config: BioBertConfig[MegatronBioBertModel, MegatronLossReduction],
    optimizer: Optional[MegatronOptimizerModule] = None,
    tokenizer: Optional[TokenizerSpec | PreTrainedTokenizerBase] = None,
    data_step: DataStep = biobert_data_step,
    forward_step: ForwardStep = bert_forward_step,
    model_transform: Optional[Callable] = None,
    **model_construct_args,
) -> BionemoLightningModule[MegatronBioBertModel, MegatronLossReduction]:
    """A pytorch lightning module for BioBert-derived models.

    This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions.
    To change your loss, pass in a different config object that returns a different loss reduction class.
    To change your model and what it outputs, pass in a different config object that returns a different model.
    Do not modify this function unless you need to change higher level logic. You may need to modify the various step
    and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some
    of those functions may need to be refactored out into the config object or a different place so that they live
    closer to the model definition.
    """
    return BionemoLightningModule(
        config=config,
        optimizer=optimizer if optimizer is not None else default_megatron_optimizer(),
        data_step=data_step,
        forward_step=forward_step,
        tokenizer=tokenizer,
        model_transform=model_transform,
        **model_construct_args,
    )
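
A hedged construction sketch: config and tokenizer are assumed to come from your own training recipe (a concrete BioBertConfig subclass instance and a TokenizerSpec or HuggingFace tokenizer); all other arguments keep the defaults shown above.

def build_module(config, tokenizer):
    """Hypothetical helper wiring a BioBert config into a BionemoLightningModule."""
    return biobert_lightning_module(
        config=config,
        tokenizer=tokenizer,
        # optimizer=None -> default_megatron_optimizer() is used;
        # data_step / forward_step keep biobert_data_step / bert_forward_step.
    )

The returned BionemoLightningModule can then be handed to a NeMo 2 trainer configured with the Megatron strategy.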

get_batch_on_this_context_parallel_rank(batch, in_place=True)

Ensures that the input batch is in the right format for the context-parallel rank.

If the context-parallel world size is greater than 1, the batch data is modified according to the context-parallel rank; otherwise, the batch is returned as-is.

Parameters

batch (Dict[str, Tensor]): The input batch data. Required.
in_place (bool): If true, the input is mutated and the returned dict is a reference to the input. Otherwise, the input is always shallow-copied, and the copy is modified and returned. Default: True

Returns

dict (Dict[str, Tensor]): The batch data modified for the context-parallel rank.

Source code in bionemo/llm/model/biobert/lightning.py
def get_batch_on_this_context_parallel_rank(batch: Dict[str, Tensor], in_place: bool = True) -> Dict[str, Tensor]:
    """Ensures that the input batch is in the right format for context parallel rank.

    Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1.
    Otherwise, the batch is returned as-is.


    Args:
        batch: The input batch data.
        in_place: If true, then the input is mutated. The returned dict is a reference to the input.
                  Otherwise, the input data is always shallow-copied and this copy is modified and returned.

    Returns:
        dict: The modified batch data based on the context parallel rank.
    """
    if not in_place:
        batch: dict[str, Tensor] = dict(**batch)

    if (cp_size := parallel_state.get_context_parallel_world_size()) > 1:
        num_valid_tokens_in_ub: Tensor | None = None
        if "loss_mask" in batch and batch["loss_mask"] is not None:
            num_valid_tokens_in_ub = batch["loss_mask"].sum()

        cp_rank = parallel_state.get_context_parallel_rank()
        for key, val in batch.items():
            if val is not None:
                seq_dim = 1 if key != "attention_mask" else 2
                _val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
                    non_blocking=True
                )
                _val = _val.index_select(seq_dim, index)
                _val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :])
                batch[key] = _val
        batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub  # type: ignore

    return batch
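
The sequence-dimension slicing above can be illustrated standalone. In this hedged sketch cp_size and cp_rank are hard-coded instead of read from parallel_state, and CPU tensors are used instead of .cuda() transfers; each rank keeps two of the 2 * cp_size sequence chunks (chunk cp_rank and its mirror).

import torch

cp_size, cp_rank = 2, 0              # hypothetical context-parallel world of 2, rank 0
seq_dim = 1
val = torch.arange(16).view(1, 16)   # [batch=1, seq=16]

# Split the sequence into 2 * cp_size chunks and keep chunks cp_rank and
# (2 * cp_size - 1 - cp_rank): for rank 0 that is chunks 0 and 3.
chunked = val.view(
    *val.shape[:seq_dim],
    2 * cp_size,
    val.shape[seq_dim] // (2 * cp_size),
    *val.shape[seq_dim + 1:],
)
index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
selected = chunked.index_select(seq_dim, index)
local = selected.view(*val.shape[:seq_dim], -1, *selected.shape[seq_dim + 2:])
print(local)  # tensor([[ 0,  1,  2,  3, 12, 13, 14, 15]])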

get_packed_seq_params(batch)

Get the packed sequence parameters for the given batch.

This function should only be called if cu_seqlens is defined in the batch.

Parameters

batch (SequenceBatch): The input batch to pack. Required.

Returns

PackedSeqParams: The packed sequence parameters with the following attributes:
- cu_seqlens_q (Tensor): The sequence lengths for the query.
- cu_seqlens_kv (Tensor): The sequence lengths for the key and value.
- max_seqlen_q (Tensor, optional): The maximum sequence length for the query.
- max_seqlen_kv (Tensor, optional): The maximum sequence length for the key and value.
- qkv_format (str): The format of the query, key, and value tensors.

Source code in bionemo/llm/model/biobert/lightning.py
def get_packed_seq_params(batch: SequenceBatch) -> PackedSeqParams:
    """Get the packed sequence parameters for the given batch.

    This function should only be called if `cu_seqlens` is defined in the batch.

    Args:
        batch: The input batch to pack.

    Returns:
        PackedSeqParams: The packed sequence parameters containing the following attributes:
            - cu_seqlens_q (Tensor): The sequence lengths for query.
            - cu_seqlens_kv (Tensor): The sequence lengths for key and value.
            - max_seqlen_q (Tensor, optional): The maximum sequence length for query.
            - max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value.
            - qkv_format (str): The format of query, key, and value tensors.

    """
    cu_seqlens = batch["cu_seqlens"].squeeze()  # remove batch size dimension (mbs=1)
    # remove -1 "paddings" added in collate_fn
    if (cu_seqlens_argmin := batch.get("cu_seqlens_argmin", None)) is not None:
        # pre-compute cu_seqlens_argmin in dataset class for perf
        cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
    else:
        cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]

    # pre-compute max_seqlens in dataset class for perf
    max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None

    # these args are passed eventually into TEDotProductAttention.forward()
    return PackedSeqParams(
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_kv=max_seqlen,
        qkv_format="thd",
    )
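
A hedged usage sketch. The cu_seqlens values describe three packed sequences of lengths 4, 3, and 5; the trailing -1 is the collate_fn padding that the argmin logic above strips, and max_seqlen is the optional pre-computed maximum.

import torch

batch = {
    "cu_seqlens": torch.tensor([[0, 4, 7, 12, -1]], dtype=torch.int32),
    "max_seqlen": torch.tensor([5], dtype=torch.int32),
}
params = get_packed_seq_params(batch)
# params.cu_seqlens_q == params.cu_seqlens_kv == tensor([0, 4, 7, 12])
# params.max_seqlen_q == params.max_seqlen_kv == tensor(5)
# params.qkv_format == "thd"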