Weight Utilities

load_weights_sharded_inplace_nemo2_to_mcore(model, distributed_checkpoint_dir, skip_keys_with_these_prefixes)

Given a Megatron module, this function determines which keys/subsets of weights to load based on the parallel/distributed state. It assumes the checkpoint was saved by a nemo2 trainer, which places the `module.` prefix on all key names, while we are loading directly into a Megatron module without the `module.` prefix. Note that if there are any extra keys that you do not want to look up in the checkpoint, for example because you added new layers/heads to your module, you need to supply the prefix paths to those keys in your model and they will be ignored. This latter feature is key for flexible fine-tuning strategies in which you partially load weights from other models with partially overlapping structures.

Parameters

Name    Type    Description     Default
model   MegatronModelType

Megatron model that you want to load weights into.

required

distributed_checkpoint_dir      str | Path

Directory of the nemo2 distributed checkpoint to load weights from.

required

skip_keys_with_these_prefixes   Set[str]

Prefixes of model keys whose weights should not be looked up in the checkpoint (for example, newly added layers/heads).

required
Source code in bionemo/llm/utils/weight_utils.py
def load_weights_sharded_inplace_nemo2_to_mcore(
    model: MegatronModelType, distributed_checkpoint_dir: str | Path, skip_keys_with_these_prefixes: Set[str]
) -> None:
    """Given a megatron module, this function will determine which keys/subsets of weights to load given the
        parallel/distributed state. This operates assuming a checkpoint was saved by a nemo2 trainer which places
        the `module.` prefix on all key names, but we are then going to load directly in to the megatron module
        without the `module.` prefix. Note that if there are any _extra_ keys that you do not want to search the
        checkpoint for, for example if you add new layers/heads onto your module, you need to supply the prefix
        path to those keys in your model and they will be ignored. This latter feature is key for flexible fine-tuning
        strategies where you load weights partially from other models with partially overlapping structures.

    Args:
        model: Megatron model that you want to load weights into.
        distributed_checkpoint_dir: Directory of the nemo2 distributed checkpoint to load weights from.
        skip_keys_with_these_prefixes: Prefixes of model keys whose weights should not be looked up in the checkpoint.
    """  # noqa: D205
    sharded_state_dict = {
        _munge_key_megatron_to_nemo2(k): _munge_sharded_tensor_key_megatron_to_nemo2(v)
        for k, v in model.sharded_state_dict().items()
        if not _key_in_filter(k, skip_keys_with_these_prefixes) and "_extra_state" not in k
    }
    dist_checkpointing.load(
        sharded_state_dict=sharded_state_dict,
        checkpoint_dir=str(Path(distributed_checkpoint_dir) / "weights"),
        strict=dist_checkpointing.serialization.StrictHandling.ASSUME_OK_UNEXPECTED,
    )
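
For illustration, a minimal usage sketch follows. The model constructor, checkpoint path, and the "regression_head." prefix are hypothetical placeholders: the skip set assumes you added a new head to the module that does not exist in the checkpoint. Note that the function reads from a weights/ subdirectory inside the directory you pass.

from pathlib import Path

from bionemo.llm.utils.weight_utils import load_weights_sharded_inplace_nemo2_to_mcore

# Hypothetical: obtain your already-initialized Megatron module however your setup constructs it.
megatron_module = build_my_megatron_module()
# Hypothetical checkpoint root saved by a nemo2 trainer; the function loads from <dir>/weights.
checkpoint_dir = Path("/results/my_nemo2_experiment/checkpoints/last")

load_weights_sharded_inplace_nemo2_to_mcore(
    model=megatron_module,
    distributed_checkpoint_dir=checkpoint_dir,
    # Hypothetical prefix for a newly added head whose weights are absent from the checkpoint.
    skip_keys_with_these_prefixes={"regression_head."},
)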

nemo1_to_nemo2_biobert_key_mapping(old_key, new_model_prefix='module', old_model_prefix='model', te_mapping=False)

This function is used to map the keys from the old nemo BERT models to the new BioBERT models.

Parameters

Name    Type    Description     Default
old_key str

The old key that we want to map to the expected new key name.

required

new_model_prefix        str

The new key prefix for the base weights. If you point this at the core megatron model, set it to "". For the regular nemo2 lightning module following standards, set it to "module". Defaults to "module".

'module'

old_model_prefix        str

The previously saved weight prefix. Defaults to "model", which was the standard in nemo1.

'model'

te_mapping      bool

Whether to apply the transformer engine naming convention, in which layernorm parameters are stored inside the linear layers rather than as separate modules. Defaults to False.

False

Returns

Name    Type    Description
str     str

The new key name.

Source code in bionemo/llm/utils/weight_utils.py
def nemo1_to_nemo2_biobert_key_mapping(  # noqa: D417
    old_key: str,
    new_model_prefix: str = "module",
    old_model_prefix: str = "model",
    te_mapping: bool = False,
) -> str:
    """This function is used to map the keys from the old nemo BERT models to the new BioBERT models

    Args:
        old_key (str): old key we want to map to the expected new key name.
        new_model_prefix (str, optional): The new key for the base weights.
            If you point this at the core megatron model set it to "".
            For the regular nemo2 lightning module following standards, set it to "module".
            Defaults to "module".
        old_model_prefix (str, optional): The previously saved weight prefix. Defaults to "model", which was the standard in nemo1.
        te_mapping (bool, optional): Whether to apply the transformer engine naming convention, where layernorm
            parameters are stored inside the linear layers. Defaults to False.

    Returns:
        str: New key name
    """  # noqa: D415
    # add the . to the end of the input prefixes if they are not the empty string,
    #  unless the user has already done so.
    if old_model_prefix != "":
        old_model_prefix = f"{old_model_prefix.rstrip('.')}."
    if new_model_prefix != "":
        new_model_prefix = f"{new_model_prefix.rstrip('.')}."

    # This function is used to map the keys from the old nemo BERT models to the new BioBERT models
    base_rename = old_key.replace(f"{old_model_prefix}language_model.", f"{new_model_prefix}")
    base_rename = base_rename.replace(f"{old_model_prefix}", f"{new_model_prefix}")
    if "dense_h_to_4h" in base_rename:
        return base_rename.replace("dense_h_to_4h", "linear_fc1")
    if "dense_4h_to_h" in base_rename:
        return base_rename.replace("dense_4h_to_h", "linear_fc2")
    if "query_key_value" in base_rename:
        return base_rename.replace("query_key_value", "linear_qkv")
    if "self_attention.dense" in base_rename:
        #  This is definitely the linear_proj and not the qkv. The linear_proj shapes are 256x256
        #   which match dense but not query_key_value
        # (Pdb) new_state_dict['encoder.layers.4.self_attention.linear_proj.weight'].shape
        #  torch.Size([256, 256])
        # (Pdb) new_state_dict['encoder.layers.4.self_attention.linear_qkv.weight'].shape
        # torch.Size([768, 256])
        # (Pdb) new_state_dict['encoder.layers.4.self_attention.linear_qkv.bias'].shape
        # torch.Size([768])
        return base_rename.replace("self_attention.dense", "self_attention.linear_proj")
    if "lm_head.bias" in base_rename:
        return base_rename.replace("lm_head.bias", "output_layer.bias")
    if "lm_head.weight" in base_rename:
        return base_rename.replace("lm_head.weight", "output_layer.weight")
    if "lm_head.layernorm" in base_rename:
        return base_rename.replace("lm_head.layernorm", "lm_head.layer_norm")

    if "post_attention_layernorm" in base_rename:
        base_rename = base_rename.replace("post_attention_layernorm", "pre_mlp_layernorm")

    # Handle the transformer engine spec's differences in layer naming and where things like layernorm are stored.
    #  TE moves layernorm from  an object that's part of the main attention layer to being an internal component of
    #  the linear layers, probably for efficiency/fusion of some sort.
    if te_mapping:
        if ".input_layernorm.weight" in base_rename:
            return base_rename.replace(".input_layernorm.weight", ".self_attention.linear_qkv.layer_norm_weight")
        if ".input_layernorm.bias" in base_rename:
            return base_rename.replace(".input_layernorm.bias", ".self_attention.linear_qkv.layer_norm_bias")
        if ".pre_mlp_layernorm.bias" in base_rename:
            return base_rename.replace(".pre_mlp_layernorm.bias", ".mlp.linear_fc1.layer_norm_bias")
        if ".pre_mlp_layernorm.weight" in base_rename:
            return base_rename.replace(".pre_mlp_layernorm.weight", ".mlp.linear_fc1.layer_norm_weight")
    return base_rename
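
For illustration, here is how the mapping behaves on a few representative nemo1-style key names. The specific keys are examples only; the expected outputs in the comments follow directly from the source above.

from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping

# MLP and attention projections are renamed to the megatron core / TE names.
print(nemo1_to_nemo2_biobert_key_mapping("model.language_model.encoder.layers.0.mlp.dense_h_to_4h.weight"))
# module.encoder.layers.0.mlp.linear_fc1.weight
print(nemo1_to_nemo2_biobert_key_mapping("model.language_model.encoder.layers.0.self_attention.query_key_value.weight"))
# module.encoder.layers.0.self_attention.linear_qkv.weight

# When loading into a bare megatron core model, drop the new prefix entirely.
print(nemo1_to_nemo2_biobert_key_mapping("model.lm_head.bias", new_model_prefix=""))
# output_layer.bias

# With te_mapping=True, layernorm parameters move inside the TE linear layers.
print(nemo1_to_nemo2_biobert_key_mapping(
    "model.language_model.encoder.layers.0.input_layernorm.weight", te_mapping=True
))
# module.encoder.layers.0.self_attention.linear_qkv.layer_norm_weight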