
ESM2QueryScaling

Bases: Module

Source code in bionemo/llm/model/layers.py, lines 45-62
class ESM2QueryScaling(torch.nn.Module):  # noqa: D101
    def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
        """A custom layer that scales quary values.

        This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2
        which apply 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb()

        Args:
            config (TransformerConfig): The megatron config. This is used for computing projection_size
        """
        super().__init__()
        projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
        self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)

    @torch.compile
    def forward(self, query, *args, **kwargs):  # noqa: D102
        return query / self.sqrt_val
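
The snippet below is a minimal usage sketch, not part of the library source. It assumes a megatron-core TransformerConfig can be constructed from just num_layers, hidden_size, and num_attention_heads (kv_channels then defaults to hidden_size // num_attention_heads), and that PyTorch 2.x with a working torch.compile backend is available for the decorated forward.

import torch
from megatron.core.transformer.transformer_config import TransformerConfig

from bionemo.llm.model.layers import ESM2QueryScaling

# Hypothetical minimal config; field values are illustrative only.
config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)

scaling = ESM2QueryScaling(config)

# Dummy query tensor shaped [sequence, batch, heads, head_dim]; any shape works,
# since forward() only divides by sqrt(hidden_size_per_attention_head).
query = torch.randn(8, 1, 4, 16)
scaled = scaling(query)  # divides by sqrt(hidden_size_per_attention_head), here sqrt(16) = 4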

__init__(config, *args, **kwargs)

A custom layer that scales query values.

This layer should replace the q_layernorm=IdentityOp in the ESM2 ModuleSpec to reproduce ESM2, which applies 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb().

Parameters

Name   | Type              | Description                                                       | Default
config | TransformerConfig | The Megatron config. This is used for computing projection_size. | required
Source code in bionemo/llm/model/layers.py, lines 46-58
def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
    """A custom layer that scales quary values.

    This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2
    which apply 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb()

    Args:
        config (TransformerConfig): The megatron config. This is used for computing projection_size
    """
    super().__init__()
    projection_size = config.kv_channels * config.num_attention_heads
    self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
    self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)

TELayerNorm

Bases: LayerNorm

Source code in bionemo/llm/model/layers.py, lines 27-42
class TELayerNorm(te.pytorch.LayerNorm):  # noqa: D101
    def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
        """A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig.
            This allows this method to be used in a megatron layerspec.

        Args:
            config (TransformerConfig): The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma.
                The rest of the config is not used.
        """  # noqa: D205
        # Eps tends to get passed through properly, as does hidden_size, but not other params from the config.
        super().__init__(
            *args,
            zero_centered_gamma=config.layernorm_zero_centered_gamma,
            sequence_parallel=config.sequence_parallel,
            **kwargs,
        )
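
As with the layer above, a hedged usage sketch rather than library-provided example code: it assumes Transformer Engine is installed, a CUDA device is available, and the same minimal TransformerConfig fields suffice. hidden_size and eps are simply forwarded through *args/**kwargs to te.pytorch.LayerNorm.

import torch
from megatron.core.transformer.transformer_config import TransformerConfig

from bionemo.llm.model.layers import TELayerNorm

# Hypothetical minimal config; layernorm_zero_centered_gamma and sequence_parallel
# keep their megatron-core defaults.
config = TransformerConfig(num_layers=2, hidden_size=64, num_attention_heads=4)

# hidden_size and eps are passed straight through to te.pytorch.LayerNorm.
norm = TELayerNorm(config, config.hidden_size, eps=config.layernorm_epsilon).cuda()

x = torch.randn(8, 1, config.hidden_size, device="cuda")
y = norm(x)  # layer-normalized output, configured from the Megatron config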

__init__(config, *args, **kwargs)

A wrapper around Transformer Engine LayerNorm that allows it to be initialized with a TransformerConfig. This allows this method to be used in a Megatron layerspec.

Parameters

Name   | Type              | Description                                                                                                                      | Default
config | TransformerConfig | The Megatron config. This is used for extracting sequence_parallel and zero_centered_gamma. The rest of the config is not used. | required
Source code in bionemo/llm/model/layers.py, lines 28-42
def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
    """A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig.
        This allows this method to be used in a megatron layerspec.

    Args:
        config (TransformerConfig): The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma.
            The rest of the config is not used.
    """  # noqa: D205
    # Eps tends to get passed through properly, as does hidden_size, but not other params from the config.
    super().__init__(
        *args,
        zero_centered_gamma=config.layernorm_zero_centered_gamma,
        sequence_parallel=config.sequence_parallel,
        **kwargs,
    )