
Fine-tune token regressor

FineTuneSeqLenBioBertConfig dataclass

Bases: BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction], IOMixinWithGettersSetters

BioBert fine-tuning sequence length model configuration.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
@dataclass
class FineTuneSeqLenBioBertConfig(
    BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction],
    iom.IOMixinWithGettersSetters,
):
    """BioBert fine-tuning sequence length model configuration."""

    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
    model_cls: Type[MegatronBioBertFineTuneSeqLengthModel] = MegatronBioBertFineTuneSeqLengthModel
    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
    # that has this new head and want to keep using these weights, please drop this next line or set to []
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])

    def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
        """Loss function type."""
        return SequenceLengthRMSEPlusBERTMLMLossWithReduction
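
The default above means a base BioBERT checkpoint can be restored without the not-yet-existing regression head weights. A minimal sketch of inspecting and overriding that behavior, assuming the bionemo-geneformer package and its Megatron dependencies are importable; everything beyond the field name is illustrative:

from dataclasses import fields

from bionemo.geneformer.model.finetune_token_regressor import FineTuneSeqLenBioBertConfig

# Inspect the default skip list without building a full config object.
skip_field = next(
    f for f in fields(FineTuneSeqLenBioBertConfig)
    if f.name == "initial_ckpt_skip_keys_with_these_prefixes"
)
print(skip_field.default_factory())  # ['regression_head']

# When resuming from a checkpoint that already contains the regression head, pass
# initial_ckpt_skip_keys_with_these_prefixes=[] (alongside the usual BioBertConfig fields)
# so those weights are restored instead of re-initialized.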

get_loss_reduction_class()

Loss function type.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
    """Loss function type."""
    return SequenceLengthRMSEPlusBERTMLMLossWithReduction

LoRAForGeneFormerTokenRegressor

Bases: LoRA

LoRA for Geneformer token regression.

There are a few tricky things here to get everything to work right:

  1. Freezing logic for the transformer has to be updated so that the new head layers are not frozen.
  2. The LoRA adapter logic has to be updated to pull the input/output sizes of the layers to be adapted from the modules that are passed (the previous method was compatible with nn and TE, but not with the Megatron tensor_parallel modules that Geneformer currently uses). This method contains a suggested refactor to make these methods more general and extensible using structural pattern matching. We should push this requirement onto NeMo, since we shouldn't duplicate the adapter method.
  3. NeMo makes many assumptions about which module is being called and that it inherits specific mixins. This could break if a Megatron module is swapped for a torch module or something else. Functional calls are generally favored for this reason, and some are made here to avoid updating inheritance throughout the code base.
Source code in bionemo/geneformer/model/finetune_token_regressor.py
class LoRAForGeneFormerTokenRegressor(LoRA):
    """LoRA for Genformer Token Regression.

    There are a few tricky things here to get everything to work right:

    1. Freezing logic for the transformer has to be updated in order to not freeze the new head layers.
    2. The LoRA adapter logic has to be updated to pull the input/output sizes of the layers to be adapted from the
       modules that are passed (the previous method was compatible with nn and TE, but not megatrons tensor_parallel
       modules that are currently used by geneformer). This method contains a suggested refactor to make these methods
       a little more general and extensible with structural pattern matching as well. We should push this
       requirement onto NeMo, since we shouldn't duplicate the adapter method.
    3. There's a ton of assumptions in NeMo about which module is being called and that it inherits specific mixins.
       This could break if it is updated from a megatron module to a torch module or something else. Functional
       calls are generally favored for this reason and some have been made here to avoid updating inheritance throughout
       the code base.
    """

    def input_size_getter(self, m: nn.Module) -> int:
        """Gets the input size of the supplied model."""
        match m:
            case object(input_size=n):
                return n
            case object(in_features=n):
                return n
            case _:
                raise ValueError(f"Module {m} does not have a supported input size calculation.")

    def output_size_getter(self, m: nn.Module) -> int:
        """Gets the output size of the supplied model."""
        match m:
            case object(output_size=n):
                return n
            case object(out_features=n):
                return n
            case _:
                raise ValueError(f"Module {m} does not have a supported output size calculation.")

    def __call__(self, model: nn.Module) -> nn.Module:
        """Inference."""
        fn.walk(model, self.selective_freeze)
        fn.walk(model, self.transform)
        return model

    def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
        """Freezes either 'encoder' or 'embedding' parameters of the input model (`m`) iff name is one of these."""
        if name in ["encoder", "embedding"]:
            FNMixin.freeze(m)
        return m

    def transform(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module | LoRALinear:
        """Transforms the input model if the name is in the target modules."""
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        if name in self.target_modules:
            # m.in_features and m.out_features are divided by tp_size already,
            # but in_features and out_features passed to ParallelLinearAdapter are not.
            if prefix is not None and "regression_head" in prefix:
                return m
            if name in ["linear_qkv", "linear_fc1"]:
                # Column Parallel Linear
                input_is_parallel = False
                in_features = self.input_size_getter(
                    m
                )  # TODO(@georgea) note that this could break depending on the impl of `m`
                out_features = self.output_size_getter(m) * tp_size
                # LoRA is applied after layernorm, so layernorm output must be returned
                m.return_layernorm_output = True
                # perf optimization for LoRA + SP
                if m.config.sequence_parallel and not m.ub_overlap_ag:
                    m.return_layernorm_output_gathered = True
            else:  # name in ['linear_proj', 'linear_fc2']
                # Row Parallel Linear
                input_is_parallel = True
                in_features = (
                    self.input_size_getter(m) * tp_size
                )  # TODO(@georgea) note this could break depending on the impl of `m`
                out_features = self.output_size_getter(m)

            adapter = ParallelLinearAdapter(
                in_features,
                out_features,
                self.dim,
                activation="identity",
                norm_position=None,
                norm_type=None,
                column_init_method=self.lora_A_init_method,
                row_init_method=self.lora_B_init_method,
                gather_output=False,
                input_is_parallel=input_is_parallel,
                dropout=self.dropout,
                dropout_position=self.dropout_position,
                model_parallel_config=getattr(m, "config", None),
                alpha=self.alpha,
            )
            return LoRALinear(m, adapter)
        return m

__call__(model)

Inference.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def __call__(self, model: nn.Module) -> nn.Module:
    """Inference."""
    fn.walk(model, self.selective_freeze)
    fn.walk(model, self.transform)
    return model
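
A hedged usage sketch of this walk-based API: the constructor arguments mirror the attributes that transform reads (target_modules, dim, alpha, dropout), but the exact LoRA signature may differ across NeMo versions, and model is assumed to be an already constructed Geneformer/BioBERT Megatron module:

# Hypothetical sketch; verify argument names against your NeMo version's LoRA class.
lora = LoRAForGeneFormerTokenRegressor(
    target_modules=["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"],
    dim=16,    # LoRA rank
    alpha=32,  # LoRA scaling factor
    dropout=0.0,
)
model = lora(model)  # freezes 'encoder'/'embedding', wraps each target layer in LoRALinear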

input_size_getter(m)

Gets the input size of the supplied model.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def input_size_getter(self, m: nn.Module) -> int:
    """Gets the input size of the supplied model."""
    match m:
        case object(input_size=n):
            return n
        case object(in_features=n):
            return n
        case _:
            raise ValueError(f"Module {m} does not have a supported input size calculation.")

output_size_getter(m)

Gets the output size of the supplied model.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def output_size_getter(self, m: nn.Module) -> int:
    """Gets the output size of the supplied model."""
    match m:
        case object(output_size=n):
            return n
        case object(out_features=n):
            return n
        case _:
            raise ValueError(f"Module {m} does not have a supported output size calculation.")
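
Both getters rely on structural pattern matching (Python 3.10+): Megatron tensor_parallel layers expose input_size/output_size, while plain torch.nn and TE linear layers expose in_features/out_features, and a missing attribute simply makes that case fail rather than raise. A self-contained illustration of the same match statement against a plain nn.Linear:

import torch.nn as nn

def get_in_size(m: nn.Module) -> int:
    # Same pattern as input_size_getter: prefer Megatron-style `input_size`,
    # then fall back to torch/TE-style `in_features`.
    match m:
        case object(input_size=n):
            return n
        case object(in_features=n):
            return n
        case _:
            raise ValueError(f"Module {m} does not have a supported input size calculation.")

print(get_in_size(nn.Linear(16, 32)))  # 16, matched via the `in_features` branch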

selective_freeze(m, name=None, prefix=None)

Freezes either the 'encoder' or 'embedding' parameters of the input model (m) iff name is one of these.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
    """Freezes either 'encoder' or 'embedding' parameters of the input model (`m`) iff name is one of these."""
    if name in ["encoder", "embedding"]:
        FNMixin.freeze(m)
    return m

transform(m, name=None, prefix=None)

Transforms the input model if the name is in the target modules.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def transform(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module | LoRALinear:
    """Transforms the input model if the name is in the target modules."""
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    if name in self.target_modules:
        # m.in_features and m.out_features are divided by tp_size already,
        # but in_features and out_features passed to ParallelLinearAdapter are not.
        if prefix is not None and "regression_head" in prefix:
            return m
        if name in ["linear_qkv", "linear_fc1"]:
            # Column Parallel Linear
            input_is_parallel = False
            in_features = self.input_size_getter(
                m
            )  # TODO(@georgea) note that this could break depending on the impl of `m`
            out_features = self.output_size_getter(m) * tp_size
            # LoRA is applied after layernorm, so layernorm output must be returned
            m.return_layernorm_output = True
            # perf optimization for LoRA + SP
            if m.config.sequence_parallel and not m.ub_overlap_ag:
                m.return_layernorm_output_gathered = True
        else:  # name in ['linear_proj', 'linear_fc2']
            # Row Parallel Linear
            input_is_parallel = True
            in_features = (
                self.input_size_getter(m) * tp_size
            )  # TODO(@georgea) note this could break depending on the impl of `m`
            out_features = self.output_size_getter(m)

        adapter = ParallelLinearAdapter(
            in_features,
            out_features,
            self.dim,
            activation="identity",
            norm_position=None,
            norm_type=None,
            column_init_method=self.lora_A_init_method,
            row_init_method=self.lora_B_init_method,
            gather_output=False,
            input_is_parallel=input_is_parallel,
            dropout=self.dropout,
            dropout_position=self.dropout_position,
            model_parallel_config=getattr(m, "config", None),
            alpha=self.alpha,
        )
        return LoRALinear(m, adapter)
    return m

MegatronBioBertFineTuneSeqLengthModel

Bases: MegatronBioBertModel

Megatron model for BioBERT fine-tuning with sequence length.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class MegatronBioBertFineTuneSeqLengthModel(MegatronBioBertModel):
    """Megatron model for biobert finetuning with sequence length."""

    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
        """Constructor."""
        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
        self.include_hiddens_finetuning = (
            include_hiddens  # this include_hiddens is for the final output of fine-tuning
        )
        # If post_process is True that means that we are at the last megatron parallelism stage and we can
        #   apply the head.
        if post_process:
            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
            self.regression_head = MegatronRegressionMLPHead(config)

    def forward(self, *args, **kwargs) -> MegatronFineTuneOutput | BioBertOutput | Tensor:
        """Inference."""
        output: MegatronFineTuneOutput | BioBertOutput | Tensor = super().forward(*args, **kwargs)
        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
        if not self.post_process:
            return output  # we are not at the last pipeline stage so just return what the parent has
        # Double check that the output from the parent has everything we need to do prediction in this head.
        if not isinstance(output, dict) or ("hidden_states" not in output):
            raise ValueError(
                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
                "Make sure include_hiddens=True in the call to super().__init__"
            )
        # Get the hidden state from the parent output, and pull out the [CLS] token for this task
        hidden_states: Tensor = output["hidden_states"][:, 0]  # [b s h] => [b h], use [CLS] (first) token for reg
        # Predict our 1d regression target
        regression_output = self.regression_head(hidden_states)
        if not self.include_hiddens_finetuning:
            del output["hidden_states"]
        output["regression_output"] = regression_output
        return output
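
On the last pipeline stage the returned dictionary carries the regression prediction alongside the parent model's outputs. A hedged sketch of consuming it, where model and batch are assumed to already exist and only the key names come from the source above:

# Hypothetical sketch: `model` is a constructed MegatronBioBertFineTuneSeqLengthModel and
# `batch` a tokenized Geneformer micro-batch expanded into forward() keyword arguments.
output = model(**batch)
seq_len_pred = output["regression_output"]  # [batch_size, 1] predicted sequence length
hidden = output.get("hidden_states")        # present only when include_hiddens=True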

__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)

Constructor.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
    """Constructor."""
    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
    self.include_hiddens_finetuning = (
        include_hiddens  # this include_hiddens is for the final output of fine-tuning
    )
    # If post_process is True that means that we are at the last megatron parallelism stage and we can
    #   apply the head.
    if post_process:
        # if we are doing post process (eg pipeline last stage) then we need to add the output layers
        self.regression_head = MegatronRegressionMLPHead(config)

forward(*args, **kwargs)

Inference.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def forward(self, *args, **kwargs) -> MegatronFineTuneOutput | BioBertOutput | Tensor:
    """Inference."""
    output: MegatronFineTuneOutput | BioBertOutput | Tensor = super().forward(*args, **kwargs)
    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
    if not self.post_process:
        return output  # we are not at the last pipeline stage so just return what the parent has
    # Double check that the output from the parent has everything we need to do prediction in this head.
    if not isinstance(output, dict) or ("hidden_states" not in output):
        raise ValueError(
            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
            "Make sure include_hiddens=True in the call to super().__init__"
        )
    # Get the hidden state from the parent output, and pull out the [CLS] token for this task
    hidden_states: Tensor = output["hidden_states"][:, 0]  # [b s h] => [b h], use [CLS] (first) token for reg
    # Predict our 1d regression target
    regression_output = self.regression_head(hidden_states)
    if not self.include_hiddens_finetuning:
        del output["hidden_states"]
    output["regression_output"] = regression_output
    return output

MegatronFineTuneOutput

Bases: BioBertOutput

Inference output type for MegatronBioBertFineTuneSeqLengthModel.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class MegatronFineTuneOutput(BioBertOutput):
    """Inference output type for MegatronBioBertFineTuneSeqLengthModel."""

    regression_output: Tensor

MegatronRegressionMLPHead

Bases: MegatronModule

A Megatron MLP head.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class MegatronRegressionMLPHead(MegatronModule):
    """A megatron MLP head."""

    def __init__(self, config: TransformerConfig):
        """Constructor."""
        super().__init__(config)
        # FC layer over just the [CLS] token embedding
        # TODO use bias/activation fusion if requested
        self.linear_fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_hidden_size)
        self.activation_function = config.activation_func
        self.linear_fc2 = nn.Linear(in_features=config.ffn_hidden_size, out_features=1)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Inference."""
        return self.linear_fc2(self.activation_function(self.linear_fc1(hidden_states)))

__init__(config)

Constructor.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def __init__(self, config: TransformerConfig):
    """Constructor."""
    super().__init__(config)
    # FC layer over just the [CLS] token embedding
    # TODO use bias/activation fusion if requested
    self.linear_fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_hidden_size)
    self.activation_function = config.activation_func
    self.linear_fc2 = nn.Linear(in_features=config.ffn_hidden_size, out_features=1)

forward(hidden_states)

Inference.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def forward(self, hidden_states: Tensor) -> Tensor:
    """Inference."""
    return self.linear_fc2(self.activation_function(self.linear_fc1(hidden_states)))
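
A self-contained shape check of the same two-layer head, rebuilt from plain nn.Linear layers with arbitrary illustrative sizes (the real head takes hidden_size, ffn_hidden_size, and the activation from TransformerConfig):

import torch
import torch.nn as nn

hidden_size, ffn_hidden_size, batch_size = 256, 1024, 4  # illustrative sizes
head = nn.Sequential(
    nn.Linear(hidden_size, ffn_hidden_size),  # linear_fc1
    nn.GELU(),                                # stand-in for config.activation_func
    nn.Linear(ffn_hidden_size, 1),            # linear_fc2 -> one regression target per sample
)
cls_embedding = torch.randn(batch_size, hidden_size)  # hidden_states[:, 0] from the parent model
print(head(cls_embedding).shape)  # torch.Size([4, 1])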

SequenceLengthRMSEPlusBERTMLMLossWithReduction

Bases: BERTMLMLossWithReduction

Loss function.

Source code in bionemo/geneformer/model/finetune_token_regressor.py
class SequenceLengthRMSEPlusBERTMLMLossWithReduction(BERTMLMLossWithReduction):
    """Loss function."""

    def forward(
        self,
        batch: SeqLenRmsepBatch,
        forward_out: Dict[str, Tensor],
    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
        """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently.

        In the future this will be extended to handle other loss types like sequence loss if it is present in the
        forward_out and batch.

        Args:
            batch: The batch of data. Each tensor should be of shape [batch_size, *, *],
                and match the corresponding dimension for that particular key in the batch output.
                For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
            forward_out: The forward output from the model. Each tensor should be of shape [batch_size, *, *]

        Taken from:
        https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
        """
        if "labels" not in batch:
            raise ValueError("Labels not provided in the batch. These are required for this loss computation.")

        unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])
        regression_output = forward_out["regression_output"]
        n_tokens = batch["attention_mask"].sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
        assert len(n_tokens.shape) == 2
        assert n_tokens.shape[-1] == 1
        rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)

        # TODO(@jstjohn) also handle different output keys, like the sequence loss.

        cp_size = parallel_state.get_context_parallel_world_size()
        if cp_size == 1:
            # reduce the loss across the micro batch
            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
        else:
            # reduce the loss across the micro batch.
            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
            #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
            #  other necessary keys to the batch. Thanks!
            loss_for_microbatch = masked_token_loss_context_parallel(
                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
            )

        # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
        #  reducing the loss across the data parallel group.
        if self.validation_step and not self.val_drop_last:
            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
            if loss_for_microbatch.isnan():
                # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
                #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
                #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
                if batch["loss_mask"].count_nonzero() != 0:
                    raise ValueError("Got NaN loss with non-empty input")
                loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
            else:
                loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch

            # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
            loss_sum_and_microbatch_size_all_gpu = torch.cat(
                [
                    loss_sum_for_microbatch.clone().detach().view(1),
                    torch.tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
                ]
            )
            torch.distributed.all_reduce(
                loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
            )
            return loss_for_microbatch * cp_size, {
                "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
            }
        loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
        # average the losses across the data parallel group, but also return the unreduced loss
        reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
        if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
            return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
        else:
            return loss_for_microbatch * cp_size, {"avg": reduced_loss}

forward(batch, forward_out)

Computes the loss between the batch's labels and the token_logits in the forward output.

In the future this will be extended to handle other loss types, such as a sequence loss, if present in the forward_out and batch.

Parameters:

    batch (SeqLenRmsepBatch, required): The batch of data. Each tensor should be of shape
        [batch_size, *, *] and match the corresponding dimension for that particular key in the
        batch output. For example, the "labels" and "token_logits" keys should hold tensors of
        shape [batch_size, sequence_length].
    forward_out (Dict[str, Tensor], required): The forward output from the model. Each tensor
        should be of shape [batch_size, *, *].

Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976

Source code in bionemo/geneformer/model/finetune_token_regressor.py
def forward(
    self,
    batch: SeqLenRmsepBatch,
    forward_out: Dict[str, Tensor],
) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
    """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently.

    In the future this will be extended to handle other loss types like sequence loss if it is present in the
    forward_out and batch.

    Args:
        batch: The batch of data. Each tensor should be of shape [batch_size, *, *],
            and match the corresponding dimension for that particular key in the batch output.
            For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
        forward_out: The forward output from the model. Each tensor should be of shape [batch_size, *, *]

    Taken from:
    https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
    """
    if "labels" not in batch:
        raise ValueError("Labels not provided in the batch. These are required for this loss computation.")

    unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])
    regression_output = forward_out["regression_output"]
    n_tokens = batch["attention_mask"].sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
    assert len(n_tokens.shape) == 2
    assert n_tokens.shape[-1] == 1
    rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)

    # TODO(@jstjohn) also handle different output keys, like the sequence loss.

    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        # reduce the loss across the micro batch
        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
    else:
        # reduce the loss across the micro batch.
        # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
        #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
        #  other necessary keys to the batch. Thanks!
        loss_for_microbatch = masked_token_loss_context_parallel(
            unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
        )

    # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
    #  reducing the loss across the data parallel group.
    if self.validation_step and not self.val_drop_last:
        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
        if loss_for_microbatch.isnan():
            # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
            #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
            #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
            if batch["loss_mask"].count_nonzero() != 0:
                raise ValueError("Got NaN loss with non-empty input")
            loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
        else:
            loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch

        # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
        loss_sum_and_microbatch_size_all_gpu = torch.cat(
            [
                loss_sum_for_microbatch.clone().detach().view(1),
                torch.tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
            ]
        )
        torch.distributed.all_reduce(
            loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
        )
        return loss_for_microbatch * cp_size, {
            "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
        }
    loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
    # average the losses across the data parallel group, but also return the unreduced loss
    reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
    if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
        return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
    else:
        return loss_for_microbatch * cp_size, {"avg": reduced_loss}
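
Despite the "RMSE" in the class name, the sequence-length term is computed with torch.nn.functional.mse_loss against per-sample token counts derived from the attention mask. A self-contained sketch of that target construction with dummy tensors:

import torch

# Dummy batch: 3 sequences padded to length 5; 1 marks real tokens, 0 marks padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1],
                               [1, 0, 0, 0, 0]])
regression_output = torch.tensor([[2.5], [5.2], [0.7]])  # hypothetical head predictions

# Same target construction as in the loss: per-sample token counts, shape [batch_size, 1].
n_tokens = attention_mask.sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
mse = torch.nn.functional.mse_loss(regression_output, n_tokens)
print(n_tokens.squeeze(-1), mse)  # tensor([3., 5., 1.]) and the scalar squared-error mean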