跳到内容

配置

IOMixinProto

基类: Protocol

NeMo 中 IOMixin 类的 get/set hparam 函数的协议。

源代码在 bionemo/llm/model/config.py
117
118
119
120
121
122
123
124
125
126
class IOMixinProto(Protocol):
    """Structural interface for NeMo's IOMixin hyper-parameter accessors.

    Any object exposing ``get_hparam`` and ``set_hparam`` with these signatures
    satisfies the protocol; no inheritance from IOMixin is required.
    """

    def get_hparam(self, attribute: str) -> Any:
        """Return the value tracked for ``attribute`` in the IOMixin-attached config."""
        ...

    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
        """Record ``value`` under ``attribute`` in the IOMixin-attached config.

        When ``also_change_value`` is True, the matching live attribute on the
        object is updated as well.
        """
        ...

get_hparam(attribute)

获取 IOMixin 附加到类的配置中属性的值。

源代码在 bionemo/llm/model/config.py
124
125
126
def get_hparam(self, attribute: str) -> Any:
    """Look up ``attribute`` in the IOMixin-managed config attached to this object."""
    ...

set_hparam(attribute, value, also_change_value=True)

设置 IOMixin 附加到类的配置中属性的值。

源代码在 bionemo/llm/model/config.py
120
121
122
def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
    """Store ``value`` under ``attribute`` in the IOMixin-managed config; optionally sync the live attribute too."""
    ...

MegatronBioNeMoModelConfig

基类: BionemoModelConfig[MegatronModelType], TransformerConfig, WillHaveGetSetHparam

用于 bionemo 的 ModelConfig 类,支持与 Megatron 模型配合使用(例如满足 NeMo2 的要求)。

源代码在 bionemo/llm/model/config.py
53
54
55
56
class MegatronBioNeMoModelConfig(BionemoModelConfig[MegatronModelType], TransformerConfig, iom.WillHaveGetSetHparam):
    """Bionemo ModelConfig variant compatible with Megatron models (as NeMo2 requires, for example)."""

    # Concrete Megatron model class that this config knows how to construct.
    model_cls: Type[MegatronModelType]

MegatronBioNeMoTrainableModelConfig dataclass

基类: MegatronBioNeMoModelConfig[MegatronModelType], BionemoTrainableModelConfig[MegatronModelType, MegatronLossType], Generic[MegatronModelType, MegatronLossType]

用于 bionemo 的 TrainableModelConfig 类,支持与 Megatron 模型配合使用(例如满足 NeMo2 的要求)。

源代码在 bionemo/llm/model/config.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
@dataclass
class MegatronBioNeMoTrainableModelConfig(
    MegatronBioNeMoModelConfig[MegatronModelType],
    BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
    Generic[MegatronModelType, MegatronLossType],
):
    """A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

    # Optional checkpoint to restore settings/weights from; None means no restoration.
    initial_ckpt_path: str | None = None
    # Weight-key prefixes to skip when restoring model weights from the checkpoint.
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
    # Fields of this config that keep their local values and are never overridden by the checkpoint.
    override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIONEMO_CONFIG_DEFAULTS)

    def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
        """Load settings into self from the checkpoint saved in self.

        Any setting in self.override_parent_fields is not overridden. Note that this function will also update the
        hyper-parameters in this config, as well as the associated attributes in self in case they were modified
        post-init.

        Args:
            initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
                other than the settings in self.override_parent_fields.

        Returns:
            None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
                a checkpoint are updated.
        """
        # Bug fix: log the path actually being loaded (the argument), not self.initial_ckpt_path,
        # which may be None or differ from the caller-supplied path.
        logger.warning(f"Loading {initial_ckpt_path}")
        # 1. Get the config from the checkpoint's io context by querying the `model.config` subpath.
        initial_config: MegatronBioNeMoTrainableModelConfig = io.load_context(
            path=Path(initial_ckpt_path) / "context", subpath="model.config"
        )  # type: ignore
        initial_fields = {f.name for f in fields(initial_config)}
        my_fields = [f.name for f in fields(self)]
        skip_fields = set(self.override_parent_fields)
        # 2. Only fields present on both configs and not explicitly protected are copied over.
        override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
        # 3. Copy both the tracked hyper-parameters and the live (possibly post-init mutated) attributes.
        override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)

    def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
        """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

        Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
            self.initial_ckpt_skip_keys_with_these_prefixes.

        Args:
            model: The Megatron model to update.
            initial_ckpt_path: The path to the megatron checkpoint to load.

        Returns:
            None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
                any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
        """
        load_weights_sharded_inplace_nemo2_to_mcore(
            model=model,  # type: ignore
            distributed_checkpoint_dir=initial_ckpt_path,
            skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
        )

load_settings_from_checkpoint(initial_ckpt_path)

从保存在 self 中的检查点加载设置到 self 中。

self.override_parent_fields 中的任何设置都不会被覆盖。请注意,此函数还将更新此配置中的超参数,以及 self 中的关联属性,以防它们在初始化后被修改。

参数

名称 类型 描述 默认值
initial_ckpt_path str

要加载的检查点路径,请注意,除了 self.override_parent_fields 中的设置外,所有内容都从此检查点加载。

必需

返回

类型 描述
None

None,设置就地加载到 self 中,并且更新稍后将保存到检查点中的超参数。

源代码在 bionemo/llm/model/config.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
    """Load settings into self from the checkpoint saved in self.

    Any setting in self.override_parent_fields is not overridden. Note that this function will also update the hyper
    parameters in this config, as well as the associated attributes in self in case they were modified post-init.

    Args:
        initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
            other than the settings in self.override_parent_fields.

    Returns:
        None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
            a checkpoint are updated.
    """
    # Bug fix: log the path actually being loaded (the argument), not self.initial_ckpt_path,
    # which may be None or differ from the caller-supplied path.
    logger.warning(f"Loading {initial_ckpt_path}")
    # 1. get the config from the trainer io context by querying the `model.config` subpath of the trainer.
    initial_config: MegatronBioNeMoTrainableModelConfig = io.load_context(
        path=Path(initial_ckpt_path) / "context", subpath="model.config"
    )  # type: ignore
    initial_fields = {f.name for f in fields(initial_config)}
    my_fields = [f.name for f in fields(self)]
    skip_fields = set(self.override_parent_fields)
    # Only fields present on both configs and not explicitly protected are copied over.
    override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
    override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)

update_model_from_checkpoint(model, initial_ckpt_path)

标准化如何从检查点加载 megatron 模型,忽略用户指定的键的实用函数。

使用提供的检查点路径中的权重更新模型,跳过前缀在 self.initial_ckpt_skip_keys_with_these_prefixes 中的键。

参数

名称 类型 描述 默认值
model MegatronModelType

要更新的 Megatron 模型。

必需
initial_ckpt_path str

要加载的 megatron 检查点路径。

必需

返回

类型 描述
None

None,模型就地更新,支持 megatron 模型并行抽象,并忽略 self.initial_ckpt_skip_keys_with_these_prefixes 中提供的任何额外键。

源代码在 bionemo/llm/model/config.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
    """Restore megatron model weights from a checkpoint, ignoring user-specified key prefixes.

    The model's weights are replaced in place with those found at ``initial_ckpt_path``; any weight whose key
        starts with one of the prefixes in self.initial_ckpt_skip_keys_with_these_prefixes is left untouched.

    Args:
        model: The Megatron model to update.
        initial_ckpt_path: The path to the megatron checkpoint to load.

    Returns:
        None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
            any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
    """
    skip_prefixes = set(self.initial_ckpt_skip_keys_with_these_prefixes)
    load_weights_sharded_inplace_nemo2_to_mcore(
        model=model,  # type: ignore
        distributed_checkpoint_dir=initial_ckpt_path,
        skip_keys_with_these_prefixes=skip_prefixes,
    )

override_mutate_possibly_extra_mutated_fiddle(target_cfg, source_cfg, maybe_mutated_elements_to_clone)

对于给定的各个元素,使用源配置中的值覆盖目标配置中的对应值。

这将修改跟踪的初始化超参数值,以及修改 self 中的关联属性,以防它们稍后被 post_init 代码修改。

参数

名称 类型 描述 默认值
target_cfg IOMixinProto

要更新的配置。

必需
source_cfg IOMixinProto

要从中复制值的配置。

必需
maybe_mutated_elements_to_clone List[str]

要从源配置复制到目标配置的元素列表。

必需

返回

类型 描述
None

None,目标配置就地更新。

源代码在 bionemo/llm/model/config.py
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def override_mutate_possibly_extra_mutated_fiddle(
    target_cfg: IOMixinProto, source_cfg: IOMixinProto, maybe_mutated_elements_to_clone: List[str]
) -> None:
    """Copy the listed elements from the source config onto the target config.

    Both the tracked init hyper-parameter values and the live attributes on the target are updated, so values
        changed later by post_init code are carried over as well.

    Args:
        target_cfg: The config to update.
        source_cfg: The config to copy values from.
        maybe_mutated_elements_to_clone: The list of elements to copy from the source config to the target config.

    Returns:
        None, the target config is updated in place.
    """
    for attribute in maybe_mutated_elements_to_clone:
        # First sync the tracked hyper-parameter without touching the live attribute: the live value may have
        #  been changed post-init, so it is mirrored separately below.
        target_cfg.set_hparam(attribute, source_cfg.get_hparam(attribute), also_change_value=False)
        # Then copy the (possibly post-init mutated) live attribute value directly.
        setattr(target_cfg, attribute, getattr(source_cfg, attribute))