Stop and Go

StopAndGoHarness

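Bases: ABC

Abstract base class for testing consistency between interrupted and continuous training.

Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata are collected through callbacks, and users can add new unit tests by comparing the metadata for the interrupted and continuous cases.

By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are compared. Users can add further metrics by adding new callbacks to cls.callbacks along with associated test functions.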

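Stop-and-go tests work as follows:
  • Set up a clean model for a brief training run and set callbacks to track metadata.
  • Interrupt training via the StopAndGoException in the callback Raise.
  • Train the model resumed from the checkpoint with the same set of callbacks.
  • Train the model continuously, without interruption, with a new set of the same callbacks.
  • Compare each pair of interrupted and continuous callbacks to check for equality.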
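The steps above can be illustrated without any of the BioNeMo machinery. The sketch below is a toy stand-in, not the harness itself: the `train` function, the dictionary state, and the deep copy that plays the role of a checkpoint are all hypothetical names introduced for illustration.

```python
import copy


def train(state, start, stop, lr=0.1):
    """Run toy gradient steps [start, stop) on f(w) = w**2, recording per-step metadata."""
    records = []
    for step in range(start, stop):
        grad = 2.0 * state["w"]  # gradient of w**2
        state["w"] -= lr * grad
        state["global_step"] = step + 1
        records.append((state["global_step"], state["w"]))
    return records


# Continuous run: six steps without interruption.
continuous_state = {"w": 1.0, "global_step": 0}
continuous = train(continuous_state, 0, 6)

# Interrupted run: stop after three steps and keep a copy as a stand-in checkpoint.
stop_state = {"w": 1.0, "global_step": 0}
stopped = train(stop_state, 0, 3)
checkpoint = copy.deepcopy(stop_state)

# Resume from the checkpoint and finish the remaining steps.
resumed = train(checkpoint, checkpoint["global_step"], 6)

# The resumed steps must match the tail of the continuous run.
assert resumed == continuous[3:]
```

If the checkpoint dropped any piece of state (for example the step counter), the final comparison would fail, which is the kind of regression the harness is meant to catch.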
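Considerations when implementing this class:
  • The derived test class name should start with Test, and test methods should start with test_, to enable pytest discovery.
  • devices, pipeline_model_parallel, and tensor_model_parallel may affect the setup of the DataModule. Certain datasets expect a known global batch size, which depends on the number of devices and on the tensor model parallel and pipeline model parallel settings. By default, tests run only on a single device without parallelism.
  • 'mode' is useful in some cases, but not in all. Implement conditions based on it when useful. For example, it may be useful to implement a test that stops and resumes while:
    • changing callbacks to test metadata integrity (the core feature of stop-and-go tests).
    • changing the model construction to use different hyperparameters.
    • ... etc. Each of the above test cases may be useful for automated testing of various expected behaviors.
  • stop(), resume(), and continuous(), or collectively run_stop_and_go(), are provided methods that execute the actual tests, leveraging the conditions in the various setup methods and respecting 'mode' where necessary.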
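To make the global batch size dependency concrete, here is a minimal sketch of the usual Megatron-style relationship between device count, model parallelism, and global batch size. The function names and the `grad_accum_steps` parameter are assumptions made for illustration, and other forms of parallelism (such as context parallelism) are ignored.

```python
def data_parallel_size(devices: int, tensor_model_parallel: int = 1, pipeline_model_parallel: int = 1) -> int:
    """Number of data-parallel replicas left after splitting devices across model parallelism."""
    model_parallel = tensor_model_parallel * pipeline_model_parallel
    if devices % model_parallel != 0:
        raise ValueError("devices must be divisible by tensor_model_parallel * pipeline_model_parallel")
    return devices // model_parallel


def global_batch_size(
    micro_batch_size: int,
    devices: int,
    tensor_model_parallel: int = 1,
    pipeline_model_parallel: int = 1,
    grad_accum_steps: int = 1,
) -> int:
    """Samples consumed per optimizer step across all data-parallel replicas."""
    replicas = data_parallel_size(devices, tensor_model_parallel, pipeline_model_parallel)
    return micro_batch_size * replicas * grad_accum_steps


# Harness default: a single device without parallelism.
single_device = global_batch_size(micro_batch_size=2, devices=1)

# Eight devices split 2 x 2 across tensor and pipeline parallelism leave two replicas.
sharded = global_batch_size(micro_batch_size=2, devices=8, tensor_model_parallel=2, pipeline_model_parallel=2)
```

A dataset that is built for one global batch size would therefore need to be reconfigured if a derived test changes any of these settings.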

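Attributes

Name Type Description
root_dir

The root directory.

val_check_interval int

The validation check interval. Stored as an attribute to ensure consistency.

exp_name str

The experiment name.

extra_metrics_dict str

A dictionary of metrics and their corresponding functions.

See also: bionemo.testing.callbacks.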

Source code in bionemo/testing/harnesses/stop_and_go.py
class StopAndGoHarness(ABC):
    """Abstract base class for testing consistency between interrupted and continuous training.

    Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata
    are collected through callbacks and users can add new unit tests by comparing the metadata for the interrupted and
    continuous cases.

    By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are
    compared. Users can add additional metrics by adding new callbacks to `cls.callbacks` and associated test functions.

    Stop and go tests act as follows:
        - set up a clean model for a brief training run, set callbacks to track.
        - interrupt training via the StopAndGoException in the callback Raise.
        - train the model resumed from the checkpoint with the same set of callbacks.
        - train the model continuously without interruption with a new set of the same callbacks.
        - compare each pair of interrupted and continuous callbacks to check for equality.

    Considerations when implementing this class:
        - The derived test name should start with `Test`, and test methods should start with `test_` to enable pytest
          discovery.
        - devices, pipeline_model_parallel, and tensor_model_parallel may impact the setup of DataModule. Certain
            datasets expect a known global batch size, which depends on the number of devices and conditional tensor
            model parallel/ pipeline model parallel settings. By default, we are testing only on single device without
            parallelism.
        - 'mode' is useful in some cases, but not in all cases. Implement conditions based on these when useful. As an
            example, it may be useful to implement a test that stops and resumes.
            - changing callbacks to test metadata integrity (core feature of stop-and-go tests).
            - changing the model construction to use different hyperparameters.
            - ... etc
            Each of the above test cases may be useful for automated testing of various expected behavior.
        - stop(), resume(), continuous() or collectively run_stop_and_go() are provided methods which execute the actual
          tests, leveraging the conditions in the various setup methods, respecting 'mode' where necessary.

    Attributes:
        root_dir: The root directory.
        val_check_interval: The validation check interval. Stored as an attribute to ensure consistency.
        exp_name: The experiment name.
        extra_metrics_dict: A dictionary of metrics and their corresponding functions.

    See Also: bionemo.testing.callbacks.
    """

    # class variables that need to be overridden
    num_steps: int
    val_check_interval: int
    limit_val_batches: int
    lr: float = 1e-4
    precision: Literal["16-mixed", "bf16-mixed", "32"]
    output_tensor_atol: float = 1e-3  # Absolute tolerance for model precision between output tensors.
    output_tensor_rtol: float = 1e-4  # Relative tolerance for model precision between output tensors.

    # class variables that will be set up in setup_class
    tempdir: tempfile.TemporaryDirectory
    metadata_dir: pathlib.Path
    exp_name: str
    callbacks: CallbackDict
    nemo_logger: NeMoLogger

    @classmethod
    def setup_class(cls) -> None:
        """Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks."""
        cls.tempdir = tempfile.TemporaryDirectory()
        cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
        cls.exp_name = cls.__name__

        cls.callbacks = cls.get_default_callbacks()

        cls.nemo_logger = NeMoLogger(
            log_dir=cls.tempdir.name,
            name=cls.exp_name,
            use_datetime_version=False,
            version=None,
            tensorboard=None,
            wandb=None,
            ckpt=None,
        )

    @classmethod
    def teardown_class(cls) -> None:
        """Tears down the class by cleaning up the temporary directory."""
        cls.tempdir.cleanup()

    @classmethod
    @abstractmethod
    def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataModule, nl.MegatronOptimizerModule]:
        """Constructs the model, data, and optimizer for the test harness.

        Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged
        to use the same code path for all three.

        Args:
            mode: The mode indicating whether to stop, resume, or train continuously.

        Returns:
            tuple: A tuple containing the model, data, and optimizer.
        """
        raise NotImplementedError()

    @classmethod
    def setup_trainer(
        cls,
        mode: Mode,
    ) -> nl.Trainer:
        """Set up the trainer by passing stop, resume, or continuous callbacks according to mode.

        Args:
            mode (Mode): The mode indicating whether to stop, resume, or train continuously.

        Returns:
            (nl.Trainer): NeMo Lightning trainer object.
        """
        strategy = MegatronStrategy(
            ddp="megatron",
            find_unused_parameters=True,
            ckpt_include_optimizer=True,
            ckpt_async_save=False,
        )

        trainer = nl.Trainer(
            devices=1,
            max_steps=cls.num_steps,
            accelerator="gpu",
            strategy=strategy,
            limit_val_batches=cls.limit_val_batches,
            val_check_interval=cls.val_check_interval,
            log_every_n_steps=cls.val_check_interval,
            num_nodes=1,
            callbacks=list(cls.callbacks[mode].values()),
            plugins=nl.MegatronMixedPrecision(precision=cls.precision),
        )
        return trainer

    @classmethod
    def get_default_callbacks(cls) -> CallbackDict:
        """Returns the default callbacks for every mode. Base implementation provides reasonable defaults.

        To extend this method, call the super and add to the callbacks, depending on which mode you are in:

        ```python
        callbacks = super().get_default_callbacks()
        callbacks[mode][MyCustomCallback] = MyCustomCallback()
        return callbacks
        ```

        Returns:
            A dictionary keyed by mode, each value of which maps a callback type to a callback
            object.
        """
        callbacks: CallbackDict = {}

        def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
            return {
                testing_callbacks.LearningRateCallback: testing_callbacks.LearningRateCallback(),
                testing_callbacks.GlobalStepStateCallback: testing_callbacks.GlobalStepStateCallback(),
                testing_callbacks.ConsumedSamplesCallback: testing_callbacks.ConsumedSamplesCallback(),
                testing_callbacks.OptimizerStateCallback: testing_callbacks.OptimizerStateCallback(),
                testing_callbacks.TrainInputCallback: testing_callbacks.TrainInputCallback(),
                testing_callbacks.TrainOutputCallback: testing_callbacks.TrainOutputCallback(),
                testing_callbacks.TrainLossCallback: testing_callbacks.TrainLossCallback(),
                testing_callbacks.ValidInputCallback: testing_callbacks.ValidInputCallback(),
                testing_callbacks.ValidOutputCallback: testing_callbacks.ValidOutputCallback(),
                testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
            }

        interrupted_callbacks = make_callbacks()
        callbacks[Mode.CONTINUOUS] = make_callbacks()

        for mode in [Mode.STOP, Mode.RESUME]:
            consumed_samples_cls = testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
            callbacks[mode] = {
                consumed_samples_cls: consumed_samples_cls(mode=mode),
                **interrupted_callbacks,
            }

        callbacks[Mode.STOP].update(
            {
                testing_callbacks.StopAfterValidEpochEndCallback: testing_callbacks.StopAfterValidEpochEndCallback(),
                nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
                    save_last=True,
                    monitor="val_loss",
                    save_top_k=2,
                    always_save_context=True,
                    filename="{epoch}-{step}-{val_loss:.2f}",
                ),
            }
        )

        return callbacks

    # stop() and resume() are provided methods and run the requisite methods with the appropriate mode.
    @classmethod
    def stop(cls) -> None:
        """Runs pre-training and 'stops' after the first checkpoint is saved.

        This method sets up the model, data, and optimizer for the Mode.STOP mode.
        It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
        The training process is executed using the `llm.train` function, passing the model, data, trainer, logger, optimizer, and resume options.
        If a `testing_callbacks.StopAndGoException` is raised during training, it is caught and no action is taken.

        Raises:
            testing_callbacks.StopAndGoException: If a stop and go exception occurs during training.
        """
        logging.info("Running stop()...")

        model, data, opt = cls.setup_model(mode=Mode.STOP)
        trainer = cls.setup_trainer(Mode.STOP)
        with distributed_model_parallel_state():
            llm.train(
                model=model,
                data=data,
                trainer=trainer,
                log=cls.nemo_logger,
                optim=opt,
                resume=resume.AutoResume(
                    resume_if_exists=False,  # Start from scratch; when True this looks for the -last checkpoint.
                    resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
                ),
            )

    @classmethod
    def resume(cls) -> None:
        """Resumes the model from the checkpoint saved at the end of `stop()` and verifies the metadata integrity."""
        logging.info("Running resume()...")

        model, data, opt = cls.setup_model(mode=Mode.RESUME)
        trainer = cls.setup_trainer(Mode.RESUME)
        with distributed_model_parallel_state():
            llm.train(
                model=model,
                data=data,
                trainer=trainer,
                log=cls.nemo_logger,
                optim=opt,
                resume=resume.AutoResume(
                    resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
                    resume_ignore_no_checkpoint=False,  # When false this will throw an error with no existing checkpoint.
                ),
            )

    @classmethod
    def continuous(cls) -> None:
        """Trains the model in one continuous path without stopping."""
        logging.info("Running continuous()...")

        model, data, opt = cls.setup_model(mode=Mode.CONTINUOUS)
        trainer = cls.setup_trainer(Mode.CONTINUOUS)
        with distributed_model_parallel_state():
            llm.train(model=model, data=data, trainer=trainer, log=cls.nemo_logger, optim=opt)

    @classmethod
    def run_stop_and_go(cls):
        """Executes training both continuously and with a checkpoint interruption."""
        # Interrupted model training
        cls.stop()
        cls.resume()

        # Cleanup and reinitialize the temporary directory so we don't conflict with a previous checkpoint.
        cls.tempdir.cleanup()
        cls.tempdir = tempfile.TemporaryDirectory()

        # Continuous model training.
        cls.continuous()

    @pytest.mark.parametrize(
        "callback_type",
        [
            testing_callbacks.LearningRateCallback,
            testing_callbacks.GlobalStepStateCallback,
            testing_callbacks.ConsumedSamplesCallback,
            testing_callbacks.OptimizerStateCallback,
            testing_callbacks.TrainInputCallback,
            testing_callbacks.TrainOutputCallback,
            testing_callbacks.TrainLossCallback,
            testing_callbacks.ValidInputCallback,
            testing_callbacks.ValidOutputCallback,
            testing_callbacks.ValidLossCallback,
        ],
    )
    def test_stop_and_go_consistency(self, callback_type):
        """Tests the consistency of the callback data between the interrupted and continuous checks."""
        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
        assert interrupted_callback.data, f"No data found for {callback_type}"

        if callback_type in {testing_callbacks.TrainOutputCallback, testing_callbacks.ValidOutputCallback}:
            atol, rtol = self.output_tensor_atol, self.output_tensor_rtol
        else:
            atol, rtol = 1e-4, 1e-4

        recursive_assert_approx_equal(
            interrupted_callback.data,
            continuous_callback.data,
            atol=atol,
            rtol=rtol,
        )

    def test_train_val_init_consumed_samples(self):
        """Tests the initial consumed samples in stop-and-go scenario."""
        train_consumed_stop, val_consumed_stop = get_callback(
            self.callbacks, Mode.STOP, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
        ).data
        train_consumed_go, val_consumed_go = get_callback(
            self.callbacks, Mode.RESUME, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
        ).data

        assert val_consumed_stop == 0
        assert val_consumed_go == 0
        assert train_consumed_stop == 0
        assert train_consumed_go > 0

continuous() classmethod

Trains the model in one continuous path without stopping.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def continuous(cls) -> None:
    """Trains the model in one continuous path without stopping."""
    logging.info("Running continuous()...")

    model, data, opt = cls.setup_model(mode=Mode.CONTINUOUS)
    trainer = cls.setup_trainer(Mode.CONTINUOUS)
    with distributed_model_parallel_state():
        llm.train(model=model, data=data, trainer=trainer, log=cls.nemo_logger, optim=opt)

get_default_callbacks() classmethod

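Returns the default callbacks for every mode. The base implementation provides reasonable defaults.

To extend this method, call super and add to the callbacks, depending on which mode you are in:

callbacks = super().get_default_callbacks()
callbacks[mode][MyCustomCallback] = MyCustomCallback()
return callbacks

Returns

Type Description
CallbackDict

A dictionary keyed by mode, each value of which maps a callback type to a callback object.

Source code in bionemo/testing/harnesses/stop_and_go.py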
@classmethod
def get_default_callbacks(cls) -> CallbackDict:
    """Returns the default callbacks for every mode. Base implementation provides reasonable defaults.

    To extend this method, call the super and add to the callbacks, depending on which mode you are in:

    ```python
    callbacks = super().get_default_callbacks()
    callbacks[mode][MyCustomCallback] = MyCustomCallback()
    return callbacks
    ```

    Returns:
        A dictionary keyed by mode, each value of which maps a callback type to a callback
        object.
    """
    callbacks: CallbackDict = {}

    def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
        return {
            testing_callbacks.LearningRateCallback: testing_callbacks.LearningRateCallback(),
            testing_callbacks.GlobalStepStateCallback: testing_callbacks.GlobalStepStateCallback(),
            testing_callbacks.ConsumedSamplesCallback: testing_callbacks.ConsumedSamplesCallback(),
            testing_callbacks.OptimizerStateCallback: testing_callbacks.OptimizerStateCallback(),
            testing_callbacks.TrainInputCallback: testing_callbacks.TrainInputCallback(),
            testing_callbacks.TrainOutputCallback: testing_callbacks.TrainOutputCallback(),
            testing_callbacks.TrainLossCallback: testing_callbacks.TrainLossCallback(),
            testing_callbacks.ValidInputCallback: testing_callbacks.ValidInputCallback(),
            testing_callbacks.ValidOutputCallback: testing_callbacks.ValidOutputCallback(),
            testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
        }

    interrupted_callbacks = make_callbacks()
    callbacks[Mode.CONTINUOUS] = make_callbacks()

    for mode in [Mode.STOP, Mode.RESUME]:
        consumed_samples_cls = testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
        callbacks[mode] = {
            consumed_samples_cls: consumed_samples_cls(mode=mode),
            **interrupted_callbacks,
        }

    callbacks[Mode.STOP].update(
        {
            testing_callbacks.StopAfterValidEpochEndCallback: testing_callbacks.StopAfterValidEpochEndCallback(),
            nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
                save_last=True,
                monitor="val_loss",
                save_top_k=2,
                always_save_context=True,
                filename="{epoch}-{step}-{val_loss:.2f}",
            ),
        }
    )

    return callbacks
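The extension pattern can be tried in isolation with stand-in classes. In the sketch below, `Mode`, `Callback`, `LearningRateCallback`, and `BaseHarness` are simplified stand-ins for the real harness types, and `GradNormCallback` is a hypothetical custom callback. It follows the source above in sharing one instance between STOP and RESUME while giving CONTINUOUS its own instance.

```python
from enum import Enum, auto


class Mode(Enum):  # stand-in for the harness Mode enum
    STOP = auto()
    RESUME = auto()
    CONTINUOUS = auto()


class Callback:  # stand-in for pl.Callback
    pass


class LearningRateCallback(Callback):  # stand-in for a default callback
    pass


class GradNormCallback(Callback):  # hypothetical custom callback
    pass


class BaseHarness:  # stand-in for StopAndGoHarness
    @classmethod
    def get_default_callbacks(cls):
        return {mode: {LearningRateCallback: LearningRateCallback()} for mode in Mode}


class MyHarness(BaseHarness):
    @classmethod
    def get_default_callbacks(cls):
        callbacks = super().get_default_callbacks()
        # The interrupted run spans STOP and RESUME, so both modes share one instance.
        interrupted = GradNormCallback()
        callbacks[Mode.STOP][GradNormCallback] = interrupted
        callbacks[Mode.RESUME][GradNormCallback] = interrupted
        # The continuous run records into its own, separate instance.
        callbacks[Mode.CONTINUOUS][GradNormCallback] = GradNormCallback()
        return callbacks


callbacks = MyHarness.get_default_callbacks()
```

Keying the inner dictionaries by callback type, rather than by a string name, is what lets a lookup by type return the matching instance for each mode.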

resume() classmethod

Resumes the model from the checkpoint saved at the end of stop() and verifies the metadata integrity.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def resume(cls) -> None:
    """Resumes the model from the checkpoint saved at the end of `stop()` and verifies the metadata integrity."""
    logging.info("Running resume()...")

    model, data, opt = cls.setup_model(mode=Mode.RESUME)
    trainer = cls.setup_trainer(Mode.RESUME)
    with distributed_model_parallel_state():
        llm.train(
            model=model,
            data=data,
            trainer=trainer,
            log=cls.nemo_logger,
            optim=opt,
            resume=resume.AutoResume(
                resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
                resume_ignore_no_checkpoint=False,  # When false this will throw an error with no existing checkpoint.
            ),
        )

run_stop_and_go() classmethod

Executes training both continuously and with a checkpoint interruption.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def run_stop_and_go(cls):
    """Executes training both continuously and with a checkpoint interruption."""
    # Interrupted model training
    cls.stop()
    cls.resume()

    # Cleanup and reinitialize the temporary directory so we don't conflict with a previous checkpoint.
    cls.tempdir.cleanup()
    cls.tempdir = tempfile.TemporaryDirectory()

    # Continuous model training.
    cls.continuous()
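The directory reset between the interrupted and continuous runs relies only on the standard library. A minimal sketch of that pattern follows; the `checkpoints/last.ckpt` file is a hypothetical stand-in for whatever the interrupted run leaves behind.

```python
import pathlib
import tempfile

# First directory: the interrupted run leaves a stand-in checkpoint behind.
tempdir = tempfile.TemporaryDirectory()
first_root = pathlib.Path(tempdir.name)
(first_root / "checkpoints").mkdir()
(first_root / "checkpoints" / "last.ckpt").write_text("stand-in checkpoint")

# Reset: remove everything, then create a fresh directory for the continuous run.
tempdir.cleanup()
tempdir = tempfile.TemporaryDirectory()
second_root = pathlib.Path(tempdir.name)

first_gone = not first_root.exists()             # old checkpoint can no longer be picked up
second_is_empty = not any(second_root.iterdir())  # continuous run starts from a clean slate

tempdir.cleanup()
```

Note that any path string captured before the reset still names the removed directory, so objects configured with the old path may need to be rebuilt.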

setup_class() classmethod

Sets up the class by creating a temporary directory, metadata_dir, exp_name, and callbacks.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def setup_class(cls) -> None:
    """Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks."""
    cls.tempdir = tempfile.TemporaryDirectory()
    cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
    cls.exp_name = cls.__name__

    cls.callbacks = cls.get_default_callbacks()

    cls.nemo_logger = NeMoLogger(
        log_dir=cls.tempdir.name,
        name=cls.exp_name,
        use_datetime_version=False,
        version=None,
        tensorboard=None,
        wandb=None,
        ckpt=None,
    )

setup_model(mode) abstractmethod classmethod

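Constructs the model, data, and optimizer for the test harness.

Optionally supports separate code paths for 'stop', 'resume', and 'continuous', although implementors are encouraged to use the same code path for all three.

Parameters

Name Type Description Default
mode Mode

The mode indicating whether to stop, resume, or train continuously.

required

Returns

Name Type Description
tuple tuple[LightningModule, LightningDataModule, MegatronOptimizerModule]

A tuple containing the model, data, and optimizer.

Source code in bionemo/testing/harnesses/stop_and_go.py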
@classmethod
@abstractmethod
def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataModule, nl.MegatronOptimizerModule]:
    """Constructs the model, data, and optimizer for the test harness.

    Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged
    to use the same code path for all three.

    Args:
        mode: The mode indicating whether to stop, resume, or train continuously.

    Returns:
        tuple: A tuple containing the model, data, and optimizer.
    """
    raise NotImplementedError()

setup_trainer(mode) classmethod

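Sets up the trainer by passing stop, resume, or continuous callbacks according to mode.

Parameters

Name Type Description Default
mode Mode

The mode indicating whether to stop, resume, or train continuously.

required

Returns

Type Description
Trainer

NeMo Lightning trainer object.

Source code in bionemo/testing/harnesses/stop_and_go.py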
@classmethod
def setup_trainer(
    cls,
    mode: Mode,
) -> nl.Trainer:
    """Set up the trainer by passing stop, resume, or continuous callbacks according to mode.

    Args:
        mode (Mode): The mode indicating whether to stop, resume, or train continuously.

    Returns:
        (nl.Trainer): NeMo Lightning trainer object.
    """
    strategy = MegatronStrategy(
        ddp="megatron",
        find_unused_parameters=True,
        ckpt_include_optimizer=True,
        ckpt_async_save=False,
    )

    trainer = nl.Trainer(
        devices=1,
        max_steps=cls.num_steps,
        accelerator="gpu",
        strategy=strategy,
        limit_val_batches=cls.limit_val_batches,
        val_check_interval=cls.val_check_interval,
        log_every_n_steps=cls.val_check_interval,
        num_nodes=1,
        callbacks=list(cls.callbacks[mode].values()),
        plugins=nl.MegatronMixedPrecision(precision=cls.precision),
    )
    return trainer

stop() classmethod

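Runs pre-training and 'stops' after the first checkpoint is saved.

This method sets up the model, data, and optimizer for Mode.STOP. It then sets up the trainer and strategy for Mode.STOP with the given metrics. Training is executed with the llm.train function, passing the model, data, trainer, logger, optimizer, and resume options. If a testing_callbacks.StopAndGoException is raised during training, it is caught and no action is taken.

Raises

Type Description
StopAndGoException

If a stop-and-go exception occurs during training.

Source code in bionemo/testing/harnesses/stop_and_go.py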
@classmethod
def stop(cls) -> None:
    """Runs pre-training and 'stops' after the first checkpoint is saved.

    This method sets up the model, data, and optimizer for the Mode.STOP mode.
    It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
    The training process is executed using the `llm.train` function, passing the model, data, trainer, logger, optimizer, and resume options.
    If a `testing_callbacks.StopAndGoException` is raised during training, it is caught and no action is taken.

    Raises:
        testing_callbacks.StopAndGoException: If a stop and go exception occurs during training.
    """
    logging.info("Running stop()...")

    model, data, opt = cls.setup_model(mode=Mode.STOP)
    trainer = cls.setup_trainer(Mode.STOP)
    with distributed_model_parallel_state():
        llm.train(
            model=model,
            data=data,
            trainer=trainer,
            log=cls.nemo_logger,
            optim=opt,
            resume=resume.AutoResume(
                resume_if_exists=False,  # Start from scratch; when True this looks for the -last checkpoint.
                resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
            ),
        )
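The two resume flags differ between stop() and resume(). The sketch below is a simplified model of their interaction, written from the comments in the source; `resolve_resume` is a hypothetical helper, `FileNotFoundError` is a stand-in for whatever error the real library raises, and the actual AutoResume behavior may differ in details.

```python
from typing import Optional


def resolve_resume(
    resume_if_exists: bool,
    resume_ignore_no_checkpoint: bool,
    last_checkpoint: Optional[str],
) -> Optional[str]:
    """Return the checkpoint to restore from, or None to start from scratch."""
    if not resume_if_exists:
        return None  # never look for a checkpoint
    if last_checkpoint is not None:
        return last_checkpoint  # continue from the -last checkpoint
    if resume_ignore_no_checkpoint:
        return None  # nothing to resume from, silently start fresh
    raise FileNotFoundError("no checkpoint found to resume from")


# stop(): resume_if_exists=False, so training starts fresh even if a checkpoint exists.
stop_choice = resolve_resume(False, True, last_checkpoint="run-last")

# resume(): resume_if_exists=True with a checkpoint present picks it up.
resume_choice = resolve_resume(True, False, last_checkpoint="run-last")
```

Under this model, calling resume() without a prior stop() fails loudly instead of quietly training from scratch, which keeps a broken checkpoint step from passing unnoticed.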

teardown_class() classmethod

Tears down the class by cleaning up the temporary directory.

Source code in bionemo/testing/harnesses/stop_and_go.py
@classmethod
def teardown_class(cls) -> None:
    """Tears down the class by cleaning up the temporary directory."""
    cls.tempdir.cleanup()

test_stop_and_go_consistency(callback_type)

Tests the consistency of the callback data between the interrupted and continuous checks.

Source code in bionemo/testing/harnesses/stop_and_go.py
@pytest.mark.parametrize(
    "callback_type",
    [
        testing_callbacks.LearningRateCallback,
        testing_callbacks.GlobalStepStateCallback,
        testing_callbacks.ConsumedSamplesCallback,
        testing_callbacks.OptimizerStateCallback,
        testing_callbacks.TrainInputCallback,
        testing_callbacks.TrainOutputCallback,
        testing_callbacks.TrainLossCallback,
        testing_callbacks.ValidInputCallback,
        testing_callbacks.ValidOutputCallback,
        testing_callbacks.ValidLossCallback,
    ],
)
def test_stop_and_go_consistency(self, callback_type):
    """Tests the consistency of the callback data between the interrupted and continuous checks."""
    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
    assert interrupted_callback.data, f"No data found for {callback_type}"

    if callback_type in {testing_callbacks.TrainOutputCallback, testing_callbacks.ValidOutputCallback}:
        atol, rtol = self.output_tensor_atol, self.output_tensor_rtol
    else:
        atol, rtol = 1e-4, 1e-4

    recursive_assert_approx_equal(
        interrupted_callback.data,
        continuous_callback.data,
        atol=atol,
        rtol=rtol,
    )
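The comparison walks nested callback data and applies absolute and relative tolerances to numeric leaves. The implementation of `recursive_assert_approx_equal` is not shown on this page, so the following is only a plain-Python stand-in that assumes the common `|a - b| <= atol + rtol * |b|` criterion and handles dictionaries, sequences, and floats rather than tensors.

```python
from collections.abc import Mapping, Sequence


def assert_approx_equal(actual, expected, atol=1e-4, rtol=1e-4, path="data"):
    """Recursively compare nested containers, using tolerances for float leaves."""
    if isinstance(expected, Mapping):
        assert isinstance(actual, Mapping), f"{path}: expected a mapping"
        assert actual.keys() == expected.keys(), f"{path}: key mismatch"
        for key in expected:
            assert_approx_equal(actual[key], expected[key], atol, rtol, f"{path}[{key!r}]")
    elif isinstance(expected, Sequence) and not isinstance(expected, (str, bytes)):
        assert len(actual) == len(expected), f"{path}: length mismatch"
        for index, (a, e) in enumerate(zip(actual, expected)):
            assert_approx_equal(a, e, atol, rtol, f"{path}[{index}]")
    elif isinstance(expected, float):
        assert abs(actual - expected) <= atol + rtol * abs(expected), f"{path}: {actual} != {expected}"
    else:
        assert actual == expected, f"{path}: {actual!r} != {expected!r}"


# Values that differ only by numerical noise pass.
continuous = {"lr": [1e-4, 9e-5], "step": 2}
interrupted = {"lr": [1e-4, 9e-5 + 1e-9], "step": 2}
assert_approx_equal(interrupted, continuous)

# A real divergence is reported with the path to the offending leaf.
try:
    assert_approx_equal({"loss": 0.52}, {"loss": 0.50})
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True
```

Output tensors get their own, looser tolerances in the harness (output_tensor_atol and output_tensor_rtol) because reduced-precision forward passes are less reproducible than scalar metadata such as the global step.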

test_train_val_init_consumed_samples()

Tests the initial consumed samples in the stop-and-go scenario.

Source code in bionemo/testing/harnesses/stop_and_go.py
def test_train_val_init_consumed_samples(self):
    """Tests the initial consumed samples in stop-and-go scenario."""
    train_consumed_stop, val_consumed_stop = get_callback(
        self.callbacks, Mode.STOP, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
    ).data
    train_consumed_go, val_consumed_go = get_callback(
        self.callbacks, Mode.RESUME, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
    ).data

    assert val_consumed_stop == 0
    assert val_consumed_go == 0
    assert train_consumed_stop == 0
    assert train_consumed_go > 0
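The four assertions can be read as simple arithmetic. The sketch below assumes that training consumed samples equal completed steps times global batch size and that the validation counter starts at zero in every run; the batch size and stop step are hypothetical values chosen for illustration.

```python
def consumed_samples(global_step: int, global_batch_size: int) -> int:
    """Samples seen so far, assuming a fixed global batch size per step."""
    return global_step * global_batch_size


global_batch_size = 4  # hypothetical value
stop_step = 2          # hypothetical step at which the checkpoint was written

# A fresh run starts with nothing consumed.
train_consumed_stop = consumed_samples(0, global_batch_size)

# A resumed run starts where the checkpoint left off.
train_consumed_go = consumed_samples(stop_step, global_batch_size)

# Assumed here: validation restarts from the first batch in both runs.
val_consumed_stop = 0
val_consumed_go = 0
```

A resumed run that reported zero consumed training samples would indicate that the data sampler position was not restored from the checkpoint.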

get_callback(callbacks, mode, callback_type)

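Returns the callback with the given type and mode.

Convenience function to make type hinting easier.

Parameters

Name Type Description Default
callbacks CallbackDict

The dictionary of callbacks.

required
mode Mode

The mode indicating whether to stop, resume, or train continuously.

required
callback_type Type[Callback]

The type of the callback.

required

Returns

Type Description
Callback

pl.Callback: The callback with the given type and mode.

Source code in bionemo/testing/harnesses/stop_and_go.py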
def get_callback(callbacks: CallbackDict, mode: Mode, callback_type: Type[Callback]) -> Callback:
    """Returns the callback with the given type and mode.

    Convenience function to make type hinting easier.

    Args:
        callbacks: The dictionary of callbacks.
        mode: The mode indicating whether to stop, resume, or train continuously.
        callback_type: The type of the callback.

    Returns:
        pl.Callback: The callback with the given type and mode.
    """
    return callbacks[mode][callback_type]  # type: ignore
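A short usage sketch of this lookup follows, again with stand-in types so that it runs without the library. `Mode`, `Callback`, and `TrainLossCallback` here are simplified stand-ins, and the `data` list is an assumed attribute modeled on how the tests above read `.data`.

```python
from enum import Enum, auto
from typing import Dict, Type


class Mode(Enum):  # stand-in for the harness Mode enum
    RESUME = auto()
    CONTINUOUS = auto()


class Callback:  # stand-in for pl.Callback
    pass


class TrainLossCallback(Callback):  # stand-in callback that accumulates data
    def __init__(self) -> None:
        self.data = []


CallbackDict = Dict[Mode, Dict[Type[Callback], Callback]]


def get_callback(callbacks: CallbackDict, mode: Mode, callback_type: Type[Callback]) -> Callback:
    """Look up the callback instance registered for a mode by its type."""
    return callbacks[mode][callback_type]


callbacks: CallbackDict = {
    Mode.RESUME: {TrainLossCallback: TrainLossCallback()},
    Mode.CONTINUOUS: {TrainLossCallback: TrainLossCallback()},
}

# Recording into the RESUME instance leaves the CONTINUOUS instance untouched.
resume_loss = get_callback(callbacks, Mode.RESUME, TrainLossCallback)
resume_loss.data.append(0.5)
continuous_loss = get_callback(callbacks, Mode.CONTINUOUS, TrainLossCallback)
```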