基类:LightningDataModule
一个混入类,为 NeMo 中的数据模块添加了用于训练恢复的 `state_dict` 和 `load_state_dict` 方法。
源代码位于 bionemo/llm/data/datamodule.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class MegatronDataModule(pl.LightningDataModule):
    """A mixin that adds a `state_dict` and `load_state_dict` method for datamodule training resumption in NeMo."""

    def __init__(self, *args, **kwargs) -> None:
        """Set init_global_step to 0 for datamodule resumption.

        All arguments are forwarded unchanged to the parent ``__init__``.
        """
        super().__init__(*args, **kwargs)
        # Global step at which the current dataloader was (re)created; refreshed by
        # ``update_init_global_step`` and subtracted from the trainer step in ``state_dict``.
        self.init_global_step = 0

    def update_init_global_step(self) -> None:
        """Please always call this when you get a new dataloader... if you forget, your resumption will not work.

        Copies ``self.trainer.global_step`` into both ``self.init_global_step`` and
        ``self.data_sampler.init_global_step`` so the two stay in sync.
        """
        # Update the init_global_step whenever we re-init training.
        self.init_global_step = self.trainer.global_step
        self.data_sampler.init_global_step = self.init_global_step

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state: a single ``"consumed_samples"`` entry computed by
            ``self.data_sampler.compute_consumed_samples`` from the steps taken since ``init_global_step``.
        """
        # Only the steps since the dataloader was last (re)created are passed in; presumably the
        # sampler adds its own starting offset (``init_consumed_samples``) — TODO confirm.
        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
        return {"consumed_samples": consumed_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.

        Raises:
            KeyError: if ``state_dict`` has no ``"consumed_samples"`` entry.
            ImportError: if neither Megatron-core nor Apex provides ``update_num_microbatches``.
        """
        try:
            from megatron.core.num_microbatches_calculator import update_num_microbatches
        except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
            from apex.transformer.pipeline_parallel.utils import update_num_microbatches

        consumed_samples = state_dict["consumed_samples"]
        # Seed the sampler with the restored sample count (presumably so iteration resumes
        # from the checkpointed position — confirm against the sampler implementation).
        self.data_sampler.init_consumed_samples = consumed_samples
        self.data_sampler.prev_consumed_samples = consumed_samples
        update_num_microbatches(
            consumed_samples=consumed_samples,
            consistency_check=False,
        )
        self.data_sampler.if_first_step = 1
|
__init__(*args, **kwargs)
将 init_global_step 设置为 0 以进行数据模块恢复。
源代码位于 bionemo/llm/data/datamodule.py
def __init__(self, *args, **kwargs) -> None:
    """Initialise the parent datamodule, then reset the resumption anchor to step 0.

    Every positional and keyword argument is passed through untouched to the parent ``__init__``.
    """
    super().__init__(*args, **kwargs)
    # Fresh datamodule: no global steps have been recorded yet.
    self.init_global_step = 0
|
load_state_dict(state_dict)
在加载检查点时调用,根据给定的数据模块状态重新加载数据模块的状态。
参数:

| 名称 | 类型 | 描述 | 默认值 |
| --- | --- | --- | --- |
| `state_dict` | `Dict[str, Any]` | 由 `state_dict` 返回的数据模块状态。 | 必需 |
源代码位于 bionemo/llm/data/datamodule.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Called when loading a checkpoint, implement to reload datamodule state given datamodule state.

    Args:
        state_dict: the datamodule state returned by ``state_dict``.

    Raises:
        KeyError: if ``state_dict`` has no ``"consumed_samples"`` entry.
        ImportError: if neither Megatron-core nor Apex provides ``update_num_microbatches``.
    """
    try:
        from megatron.core.num_microbatches_calculator import update_num_microbatches
    except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
        from apex.transformer.pipeline_parallel.utils import update_num_microbatches

    consumed_samples = state_dict["consumed_samples"]
    # Seed the sampler with the restored sample count (presumably so iteration resumes
    # from the checkpointed position — confirm against the sampler implementation).
    self.data_sampler.init_consumed_samples = consumed_samples
    self.data_sampler.prev_consumed_samples = consumed_samples
    update_num_microbatches(
        consumed_samples=consumed_samples,
        consistency_check=False,
    )
    self.data_sampler.if_first_step = 1
|
state_dict()
在保存检查点时调用,用于生成并保存数据模块的状态。
返回:包含数据模块状态的字典。
源代码位于 bionemo/llm/data/datamodule.py
38
39
40
41
42
43
44
45
def state_dict(self) -> Dict[str, Any]:
    """Build the datamodule state that is persisted when a checkpoint is saved.

    Returns:
        A dictionary containing datamodule state: the sample count reported by
        ``self.data_sampler.compute_consumed_samples`` for the steps taken since ``init_global_step``.
    """
    steps_since_init = self.trainer.global_step - self.init_global_step
    return {"consumed_samples": self.data_sampler.compute_consumed_samples(steps_since_init)}
|
update_init_global_step()
每次获取新的数据加载器时都请调用此函数;如果忘记调用,训练恢复将无法正常工作。
源代码位于 bionemo/llm/data/datamodule.py
def update_init_global_step(self):
    """Record the trainer's current global step as the starting step of the datamodule.

    Call this every time a new dataloader is obtained; if it is skipped, resumption will not work.
    The recorded value is mirrored onto ``self.data_sampler`` so both copies agree.
    """
    # Re-anchor both counters whenever training is (re)initialised.
    current_step = self.trainer.global_step
    self.init_global_step = current_step
    self.data_sampler.init_global_step = self.init_global_step
|