
Datamodule

MegatronDataModule

Bases: `LightningDataModule`

A mixin that adds `state_dict` and `load_state_dict` methods for datamodule training resumption in NeMo.

Source code in bionemo/llm/data/datamodule.py
class MegatronDataModule(pl.LightningDataModule):
    """A mixin that adds `state_dict` and `load_state_dict` methods for datamodule training resumption in NeMo."""

    def __init__(self, *args, **kwargs):
        """Set init_global_step to 0 for datamodule resumption."""
        super().__init__(*args, **kwargs)
        self.init_global_step = 0

    def update_init_global_step(self):
        """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
        self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
        self.data_sampler.init_global_step = (
            self.init_global_step
        )  # Update the init_global_step whenever we re-init training

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state.

        """
        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
        return {"consumed_samples": consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload datamodule state given the datamodule state.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.

        """
        try:
            from megatron.core.num_microbatches_calculator import update_num_microbatches

        except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import update_num_microbatches

        consumed_samples = state_dict["consumed_samples"]
        self.data_sampler.init_consumed_samples = consumed_samples
        self.data_sampler.prev_consumed_samples = consumed_samples

        update_num_microbatches(
            consumed_samples=consumed_samples,
            consistency_check=False,
        )
        self.data_sampler.if_first_step = 1

__init__(*args, **kwargs)

Set `init_global_step` to 0 for datamodule resumption.

Source code in bionemo/llm/data/datamodule.py
def __init__(self, *args, **kwargs):
    """Set init_global_step to 0 for datamodule resumption."""
    super().__init__(*args, **kwargs)
    self.init_global_step = 0

load_state_dict(state_dict)

Called when loading a checkpoint; implement to reload datamodule state from the given state.

Parameters

    state_dict (Dict[str, Any]): the datamodule state returned by `state_dict`. Required.
Source code in bionemo/llm/data/datamodule.py
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Called when loading a checkpoint, implement to reload datamodule state given the datamodule state.

    Args:
        state_dict: the datamodule state returned by ``state_dict``.

    """
    try:
        from megatron.core.num_microbatches_calculator import update_num_microbatches

    except (ImportError, ModuleNotFoundError):
        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
        from apex.transformer.pipeline_parallel.utils import update_num_microbatches

    consumed_samples = state_dict["consumed_samples"]
    self.data_sampler.init_consumed_samples = consumed_samples
    self.data_sampler.prev_consumed_samples = consumed_samples

    update_num_microbatches(
        consumed_samples=consumed_samples,
        consistency_check=False,
    )
    self.data_sampler.if_first_step = 1

state_dict()

Called when saving a checkpoint; implement to generate and save datamodule state.

Returns

    Dict[str, Any]: a dictionary containing datamodule state.

Source code in bionemo/llm/data/datamodule.py
def state_dict(self) -> Dict[str, Any]:
    """Called when saving a checkpoint, implement to generate and save datamodule state.

    Returns:
        A dictionary containing datamodule state.

    """
    consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
    return {"consumed_samples": consumed_samples}

update_init_global_step()

Always call this when you get a new dataloader; if you forget, your resumption will not work.

Source code in bionemo/llm/data/datamodule.py
def update_init_global_step(self):
    """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
    self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
    self.data_sampler.init_global_step = (
        self.init_global_step
    )  # Update the init_global_step whenever we re-init training
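A stand-alone numeric sketch of why this call matters (the function and the numbers are illustrative, not from BioNeMo): if `init_global_step` is not updated when a new dataloader is created mid-run, steps taken before the re-init are counted twice and the consumed-sample count drifts.

```python
GLOBAL_BATCH_SIZE = 4

def compute_consumed_samples(init_consumed_samples: int, steps_since_init: int) -> int:
    # Mirrors the sampler arithmetic: samples consumed since this dataloader
    # was initialized, on top of any carried-over count.
    return init_consumed_samples + steps_since_init * GLOBAL_BATCH_SIZE

# Resumed at global step 10 with 40 samples already consumed, then 5 more steps.
init_consumed_samples, global_step = 40, 15

# Correct: update_init_global_step() ran at resumption, so init_global_step == 10
# and only the 5 post-resume steps are counted.
assert compute_consumed_samples(init_consumed_samples, global_step - 10) == 60

# Forgotten: init_global_step stayed 0, so the 10 pre-resume steps are
# double-counted and the tally inflates to 100 instead of 60.
assert compute_consumed_samples(init_consumed_samples, global_step - 0) == 100
```

An inflated count shifts where the sampler resumes in the dataset, so the model silently skips data it never saw.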