
Loss

BERTMLMLossWithReduction

Bases: _Nemo2CompatibleLossReduceMixin, MegatronLossReduction

Source code in bionemo/llm/model/loss.py
class BERTMLMLossWithReduction(_Nemo2CompatibleLossReduceMixin, MegatronLossReduction):  # noqa: D101
    def __init__(
        self,
        validation_step: bool = False,
        val_drop_last: bool = True,
        send_train_output: bool = False,
        send_val_output: bool = True,
    ) -> None:
        """Initializes the Model class.

        Args:
            validation_step (bool, optional): Whether this object is being applied to the validation step. Defaults to False.
            val_drop_last (bool, optional): Whether the last batch is configured to be dropped during validation. Defaults to True.
            send_train_output (bool): Whether to return the model output in training. Defaults to False.
            send_val_output (bool, optional): Whether to return the model output in validation. Defaults to True.
            include_forward_output_for_metrics (bool): Some downstream metrics such as perplexity require this. It can be
                expensive to return however, so disable this if performance is a top consideration.
        """
        # TODO(@jomitchell): Track down how we handle test. This is a common pattern in NeMo2, but these parameters seem likely
        #  to change in the future.
        super().__init__()
        self.validation_step = validation_step
        self.val_drop_last = val_drop_last
        self.send_train_output = send_train_output
        self.send_val_output = send_val_output

    def forward(
        self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
        """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently. In the future this will be extended
            to handle other loss types like sequence loss if it is present in the forward_out and batch.

        Args:
            batch (Dict[str, Tensor]): The batch of data. Each tensor should be of shape [batch_size, *, *],
                and match the corresponding dimension for that particular key in the batch output.
                For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
            forward_out (Dict[str, Tensor]): The forward output from the model. Each tensor should be of shape [batch_size, *, *]

        Taken from:
        https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .
        """  # noqa: D205
        if "labels" not in batch:
            raise ValueError("Labels not provided in the batch. These are required for this loss computation.")

        train_step: bool = not self.validation_step
        # Determine if we need to capture/send forward output for downstream metrics, such as perplexity logging
        #  this is expensive so only do if necessary.
        send_forward_output: bool = (self.validation_step and self.send_val_output) or (
            train_step and self.send_train_output
        )

        if send_forward_output:
            forward_out_report = {
                k: v.detach().clone() if torch.is_tensor(v) else v for k, v in forward_out.items()
            }  # avoid impact from inplace operation on token_logits in unreduced_token_loss_fn
        else:
            forward_out_report = {}

        # NOTE: token_logits is [sequence, batch] but labels and other fields, including the loss, are [batch, sequence]
        unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])  # [b s]

        # TODO(@jstjohn) also handle different output keys, like the sequence loss.

        # compute loss
        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            # reduce the loss across the micro batch per valid token
            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
        else:
            # reduce the loss across the micro batch per valid token.
            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
            #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
            #  other necessary keys to the batch. Thanks!
            loss_for_microbatch = masked_token_loss_context_parallel(
                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
            )

        # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
        #  reducing the loss across the data parallel group.
        if self.validation_step and not self.val_drop_last:
            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
            if loss_for_microbatch.isnan():
                # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
                #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
                #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
                if batch["loss_mask"].count_nonzero() != 0:
                    raise ValueError("Got NaN loss with non-empty input")
                loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
            else:
                loss_sum_for_microbatch = (
                    num_valid_tokens_in_microbatch * loss_for_microbatch
                )  # sum over all valid tokens

            # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
            loss_sum_and_microbatch_size_all_gpu = torch.cat(
                [
                    loss_sum_for_microbatch.clone().detach().view(1),
                    Tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
                ]
            )
            torch.distributed.all_reduce(
                loss_sum_and_microbatch_size_all_gpu,
                group=parallel_state.get_data_parallel_group(),
                op=torch.distributed.ReduceOp.SUM,
            )
            return loss_for_microbatch * cp_size, {
                "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
            }

        # average the losses across the data parallel group, but also return the unreduced loss
        reduced_loss = average_losses_across_data_parallel_group([loss_for_microbatch])
        if send_forward_output:
            return loss_for_microbatch * cp_size, {
                "avg": reduced_loss,
                "batch": batch,
                "forward_out": forward_out_report,
            }
        else:
            return loss_for_microbatch * cp_size, {"avg": reduced_loss}
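
Conceptually, the loss_sum_and_microbatch_size payload returned on the validation path with val_drop_last=False lets the consumer recover an exact token-weighted mean across microbatches of different sizes. Below is a minimal sketch of that arithmetic with made-up numbers; the actual aggregation happens in the reduction step that consumes this dictionary, which is not shown on this page.

import torch

# Hypothetical per-microbatch [loss_sum, num_valid_tokens] pairs, already summed over the DP group.
pairs = [torch.tensor([295.7, 128.0]), torch.tensor([140.2, 64.0])]
totals = torch.stack(pairs).sum(dim=0)  # [total_loss_sum, total_valid_tokens]
mean_val_loss = totals[0] / totals[1]   # exact per-token mean over all microbatches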

__init__(validation_step=False, val_drop_last=True, send_train_output=False, send_val_output=True)

Initializes the Model class.

Parameters

validation_step (bool): Whether this object is being applied to the validation step. Default: False.
val_drop_last (bool): Whether the last batch is configured to be dropped during validation. Default: True.
send_train_output (bool): Whether to return the model output in training. Default: False.
send_val_output (bool): Whether to return the model output in validation. Default: True.
include_forward_output_for_metrics (bool): Some downstream metrics, such as perplexity, require this. It can be expensive to return, however, so disable it if performance is a top consideration. Required.
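
For reference, a minimal construction sketch (argument values are illustrative, not prescriptive):

# Training-time loss reduction: skip returning the forward output to save memory.
train_loss_reduction = BERTMLMLossWithReduction(validation_step=False, send_train_output=False)

# Validation-time loss reduction: keep the forward output so metrics such as perplexity
# can be computed from the returned logits, and handle a partial final batch.
val_loss_reduction = BERTMLMLossWithReduction(validation_step=True, val_drop_last=False, send_val_output=True)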

forward(batch, forward_out)

Computes the loss of `labels` in the batch vs `token_logits` in the forward output. In the future, this will be extended to handle other loss types, such as sequence loss, if they are present in forward_out and the batch.

Parameters

batch (Dict[str, Tensor]): The batch of data. Each tensor should be of shape [batch_size, *, *] and match the corresponding dimension for that particular key in the batch output. For example, the "labels" and "token_logits" keys should have tensors of shape [batch_size, sequence_length]. Required.
forward_out (Dict[str, Tensor]): The forward output from the model. Each tensor should be of shape [batch_size, *, *]. Required.

Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
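
To make the expected shapes concrete, here is a hedged sketch of a call. It assumes torch.distributed and Megatron's parallel_state are already initialized (as they are inside a NeMo2 training job), since the reduction helpers need an active data-parallel group; tensor names and sizes are illustrative only.

import torch

# Hypothetical sizes: batch_size=4, sequence_length=128, vocab_size=32000.
batch = {
    "labels": torch.randint(0, 32000, (4, 128)),        # [batch, sequence]
    "loss_mask": torch.ones(4, 128, dtype=torch.long),  # 1 where a token counts toward the loss
}
forward_out = {
    "token_logits": torch.randn(128, 4, 32000),         # [sequence, batch, vocab]
}

loss_reduction = BERTMLMLossWithReduction(validation_step=False)
loss, reduction_dict = loss_reduction.forward(batch, forward_out)
# reduction_dict["avg"] holds the loss averaged across the data parallel group.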


DataParallelGroupLossAndIO

Bases: TypedDict

Average losses across the data parallel group, plus the original batch and inference output.

Source code in bionemo/llm/model/loss.py
class DataParallelGroupLossAndIO(TypedDict):
    """Average losses across the data parallel group + the original batch and inference output."""

    avg: Tensor
    batch: dict[str, Tensor]
    forward_out: dict[str, Tensor]

PerTokenLossDict

Bases: TypedDict

Tensor dictionary for loss.

This is the return type for a loss computed per token in the batch, supporting microbatches of varying sizes.

Source code in bionemo/llm/model/loss.py
class PerTokenLossDict(TypedDict):
    """Tensor dictionary for loss.

    This is the return type for a loss that is computed per token in the batch, supporting microbatches of varying sizes.
    """

    loss_sum_and_microbatch_size: Tensor

SameSizeLossDict

Bases: TypedDict

Tensor dictionary for loss.

This is the return type for a loss computed over the entire batch, where all microbatches are the same size.

Source code in bionemo/llm/model/loss.py
class SameSizeLossDict(TypedDict):
    """Tensor dictionary for loss.

    This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.
    """

    avg: Tensor
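
To relate these types back to forward, the second element of its return value matches one of these dictionaries depending on configuration. A descriptive sketch with made-up values follows (assuming the standard import path bionemo.llm.model.loss):

import torch

from bionemo.llm.model.loss import PerTokenLossDict, SameSizeLossDict

same_size: SameSizeLossDict = {"avg": torch.tensor(2.31)}  # default path: loss averaged over the DP group
per_token: PerTokenLossDict = {                            # validation with val_drop_last=False
    "loss_sum_and_microbatch_size": torch.tensor([295.7, 128.0])
}
# DataParallelGroupLossAndIO additionally carries "batch" and "forward_out"
# when send_train_output / send_val_output is enabled.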

unreduced_token_loss_fn(logits, labels, cross_entropy_loss_fusion=False)

Computes the unreduced token loss given the logits and labels, without regard to the loss mask.

WARNING: This function does not apply a loss mask. It also performs an in-place operation on its inputs.

Parameters

logits (Tensor): The predicted logits of shape [sequence_length, batch_size, num_classes]. Required.
labels (Tensor): The true labels of shape [batch_size, sequence_length]. Required.
cross_entropy_loss_fusion (bool): If True, use the fused-kernel version of vocab-parallel cross entropy. This should generally be preferred for speed, since it packs more operations into a single GPU kernel; however, some users have observed reduced training stability with this method. Default: False.

Returns

Tensor: The unreduced token loss of shape [batch_size, sequence_length].

Source code in bionemo/llm/model/loss.py
def unreduced_token_loss_fn(logits: Tensor, labels: Tensor, cross_entropy_loss_fusion: bool = False) -> Tensor:
    """Computes the unreduced token loss given the logits and labels without regard to the loss mask.

    WARNING: This function does not apply a loss mask. Also, it does inplace operation on the inputs.

    Args:
        logits (Tensor): The predicted logits of shape [sequence_length, batch_size, num_classes].
        labels (Tensor): The true labels of shape [batch_size, sequence_length].
        cross_entropy_loss_fusion (bool): If True, use the fused kernel version of vocab parallel cross entropy. This
            should generally be preferred for speed as it packs more operations into a single kernel on the GPU. However
            some users have observed reduced training stability when using this method.

    Returns:
        Tensor: The unreduced token loss of shape [batch_size, sequence_length].
    """
    labels = labels.transpose(0, 1).contiguous()  # [b, s] -> [s, b]
    if cross_entropy_loss_fusion:
        loss = fused_vocab_parallel_cross_entropy(logits, labels)
    else:
        loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
    # [s b] => [b, s]
    loss = loss.transpose(0, 1).contiguous()
    return loss
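
For intuition, with no tensor parallelism the vocab-parallel cross entropy reduces to an ordinary unreduced cross entropy. The sketch below reproduces only the shape handling in plain PyTorch; it bypasses Megatron's parallel kernels and is for illustration, not a drop-in replacement.

import torch
import torch.nn.functional as F

def unreduced_token_loss_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Shape-for-shape reference: logits [s, b, v] and labels [b, s] -> per-token loss [b, s]."""
    s, b, v = logits.shape
    labels_sb = labels.transpose(0, 1).contiguous()  # [b, s] -> [s, b]
    loss = F.cross_entropy(
        logits.reshape(s * b, v), labels_sb.reshape(s * b), reduction="none"
    ).view(s, b)
    return loss.transpose(0, 1).contiguous()         # [s, b] -> [b, s]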