API

ESM2Config dataclass

Bases: ESM2GenericConfig, IOMixinWithGettersSetters

Configuration class for the ESM2 model.

Source code in bionemo/esm2/model/model.py, lines 358-364
@dataclass
class ESM2Config(ESM2GenericConfig, iom.IOMixinWithGettersSetters):
    """Configuration class for ESM2 model."""

    model_cls: Type[ESM2Model] = ESM2Model
    num_layers: int = 33  # 650M
    hidden_size: int = 1280  # 650M
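
The snippet below is a usage sketch, not part of the listing above: ESM2Config defaults to the 650M architecture, and individual fields can be overridden at construction time. Whether the dataclass builds cleanly standalone depends on your Megatron / Transformer Engine installation, so treat this as illustrative.

from bionemo.esm2.model.model import ESM2Config

config = ESM2Config()  # 650M defaults: 33 layers, hidden_size 1280
# Override selected fields for a smaller, purely illustrative variant.
small_config = ESM2Config(num_layers=6, hidden_size=320, ffn_hidden_size=4 * 320)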

ESM2GenericConfig dataclass

Bases: BioBertConfig[ESM2ModelT, MegatronLossType]

Configuration class for the ESM2 model.

Attributes

Name  Type  Description
num_layers  int  Number of layers in the model.
hidden_size  int  Hidden size of the model.
num_attention_heads  int  Number of attention heads in the model.
ffn_hidden_size  int  Hidden size of the feed-forward network.
hidden_dropout  float  Dropout rate for hidden layers.
attention_dropout  float  Dropout rate for attention layers.
apply_residual_connection_post_layernorm  bool  Whether to apply the residual connection after layer normalization.
layernorm_epsilon  float  Epsilon value for layer normalization.
layernorm_zero_centered_gamma  bool  Whether to zero-center the gamma parameter in layer normalization.
activation_func  Callable  Activation function used in the model.
init_method_std  float  Standard deviation for weight initialization.
apply_query_key_layer_scaling  bool  Whether to apply scaling to the query and key layers.
masked_softmax_fusion  bool  Whether to use a kernel that fuses the attention softmax with its mask.
fp16_lm_cross_entropy  bool  Whether to move the unreduced cross-entropy loss calculation for the LM head to fp16.
share_embeddings_and_output_weights  bool  Whether to share embeddings and output weights.
enable_autocast  bool  Whether to enable autocast for mixed precision.
biobert_spec_option  BiobertSpecOption  BiobertSpecOption for the model.
position_embedding_type  PositionEmbeddingKinds  Type of position embedding used in the model.
seq_length  int  Length of the input sequence.
make_vocab_size_divisible_by  int  Make the vocabulary size divisible by this value.
token_dropout  bool  Whether to apply token dropout.
use_attention_mask  bool  Whether to use an attention mask.
use_esm_attention  bool  Whether to use ESM attention.
attention_softmax_in_fp32  bool  Whether to use fp32 for the attention softmax.
optimizer_fn  Optional[Callable[[MegatronBioBertModel], Optimizer]]  Optional optimizer function for the model.
parallel_output  bool  Whether to use parallel output.
rotary_base  int  Base value for rotary positional encoding.
rotary_percent  float  Percentage of rotary positional encoding.
seq_len_interpolation_factor  Optional[float]  Interpolation factor for sequence length.
get_attention_mask_from_fusion  bool  Whether to get the attention mask from fusion.
nemo1_ckpt_path  str | None  Path to a NEMO1 checkpoint.
return_only_hidden_states  bool  Whether to return only hidden states.
loss_reduction_class  Type[MegatronLossType]  Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.

Source code in bionemo/esm2/model/model.py, lines 243-355
@dataclass
class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):
    """Configuration class for ESM2 model.

    Attributes:
        num_layers: Number of layers in the model.
        hidden_size: Hidden size of the model.
        num_attention_heads: Number of attention heads in the model.
        ffn_hidden_size: Hidden size of the feed-forward network.
        hidden_dropout: Dropout rate for hidden layers.
        attention_dropout: Dropout rate for attention layers.
        apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization.
        layernorm_epsilon: Epsilon value for layer normalization.
        layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization.
        activation_func: Activation function used in the model.
        init_method_std: Standard deviation for weight initialization.
        apply_query_key_layer_scaling: Whether to apply scaling to query and key layers.
        masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask.
        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
        share_embeddings_and_output_weights: Whether to share embeddings and output weights.
        enable_autocast: Whether to enable autocast for mixed precision.
        biobert_spec_option: BiobertSpecOption for the model.
        position_embedding_type: Type of position embedding used in the model.
        seq_length: Length of the input sequence.
        make_vocab_size_divisible_by: Make the vocabulary size divisible by this value.
        token_dropout: Whether to apply token dropout.
        use_attention_mask: Whether to use attention mask.
        use_esm_attention: Whether to use ESM attention.
        attention_softmax_in_fp32: Whether to use fp32 for attention softmax.
        optimizer_fn: Optional optimizer function for the model.
        parallel_output: Whether to use parallel output.
        rotary_base: Base value for rotary positional encoding.
        rotary_percent: Percentage of rotary positional encoding.
        seq_len_interpolation_factor: Interpolation factor for sequence length.
        get_attention_mask_from_fusion: Whether to get attention mask from fusion.
        nemo1_ckpt_path: Path to NEMO1 checkpoint.
        return_only_hidden_states: Whether to return only hidden states.
        loss_reduction_class: Loss reduction class for the model. Default to BERTMLMLossWithReduction.
    """

    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
    model_cls: Type[ESM2ModelT] = ESM2Model
    num_layers: int = 33  # 650M
    hidden_size: int = 1280  # 650M
    num_attention_heads: int = 20
    ffn_hidden_size: int = 4 * 1280  # Transformer FFN hidden size. Usually 4 * hidden_size.
    hidden_dropout: float = 0  # ESM2 removes dropout from hidden layers and attention
    attention_dropout: float = 0.0  # ESM2 does not use attention dropout
    apply_residual_connection_post_layernorm: bool = False  # TODO: farhadr False is new default, True was BERT pub.
    layernorm_epsilon: float = 1.0e-5
    bias_activation_fusion: bool = True  # True degrades accuracy slightly, but is faster.
    activation_func: Callable = F.gelu  # esm_gelu_func  # ESM2 MLP
    init_method_std: float = 0.02
    softmax_scale: float = 1.0

    # embedding
    token_dropout: bool = True
    use_attention_mask: bool = True

    # core attention
    use_esm_attention: bool = False  # Skip ESM2 custom attention for TE acceleration. Still passes golden value test.
    attention_softmax_in_fp32: bool = False
    normalize_attention_scores: bool = False

    # From megatron.core.models.gpt.bert_model.GPTModel
    fp16_lm_cross_entropy: bool = False  # Move the cross entropy unreduced loss calculation for lm head to fp16
    parallel_output: bool = True
    share_embeddings_and_output_weights: bool = True
    make_vocab_size_divisible_by: int = 128
    position_embedding_type: PositionEmbeddingKinds = "rope"  # ESM2 uses relative positional encoding 'ROPE' to extrapolate to longer sequences unseen during training
    rotary_base: int = 10000
    rotary_percent: float = 1.0
    seq_len_interpolation_factor: Optional[float] = None
    seq_length: int = 1024
    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec

    optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None
    # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins
    #  support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally.
    nemo1_ckpt_path: str | None = None
    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
    #  self.override_parent_fields will be loaded from the checkpoint and override those values here.
    initial_ckpt_path: str | None = None
    # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
    #  things as part of the workflow for inference and fine-tuning.
    return_embeddings: bool = False
    include_embeddings: bool = False
    include_input_ids: bool = False
    skip_logits: bool = False
    return_only_hidden_states: bool = False  # return logits

    def __post_init__(self):
        # TODO, as a validator?
        """Check configuration compatibility."""
        # reset moe_token_dispatcher_type when variable_seq_lengths is True.
        # must be performed before super().__post_init__()
        if self.variable_seq_lengths and self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
            logging.warning(
                "MoE token dispatcher type 'allgather' and 'alltoall_seq' are not supported with variable sequence lengths. Setting moe_token_dispatcher_type to 'alltoall'."
            )
            self.moe_token_dispatcher_type = "alltoall"

        # reset apply_query_key_layer_scaling based on biobert_spec_option
        super().__post_init__()
        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
            self.apply_query_key_layer_scaling = False
        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
            logging.warning(
                "BiobertSpecOption.esm2_bert_layer_local_spec is depreciated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
            )
            self.apply_query_key_layer_scaling = True
        else:
            raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
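
As a hedged sketch of the subclassing pattern used by ESM2Config above: the field names come from the listing, the re-declared types follow the "always declare types when overriding" comment, and the class name and small sizes below are purely illustrative.

from dataclasses import dataclass
from typing import Type


@dataclass
class MyTinyESM2Config(ESM2GenericConfig, iom.IOMixinWithGettersSetters):
    """Illustrative config for a small ESM2-style model (hypothetical)."""

    model_cls: Type[ESM2Model] = ESM2Model
    num_layers: int = 6  # re-declare types when overriding dataclass fields
    hidden_size: int = 320
    num_attention_heads: int = 20
    ffn_hidden_size: int = 4 * 320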

__post_init__()

Check configuration compatibility.

Source code in bionemo/esm2/model/model.py, lines 334-355
def __post_init__(self):
    # TODO, as a validator?
    """Check configuration compatibility."""
    # reset moe_token_dispatcher_type when variable_seq_lengths is True.
    # must be performed before super().__post_init__()
    if self.variable_seq_lengths and self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
        logging.warning(
            "MoE token dispatcher type 'allgather' and 'alltoall_seq' are not supported with variable sequence lengths. Setting moe_token_dispatcher_type to 'alltoall'."
        )
        self.moe_token_dispatcher_type = "alltoall"

    # reset apply_query_key_layer_scaling based on biobert_spec_option
    super().__post_init__()
    if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
        self.apply_query_key_layer_scaling = False
    elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
        logging.warning(
            "BiobertSpecOption.esm2_bert_layer_local_spec is depreciated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
        )
        self.apply_query_key_layer_scaling = True
    else:
        raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
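
A minimal sketch of the validation shown above: the Transformer Engine spec disables query/key layer scaling, while the deprecated local spec enables it. The import path for BiobertSpecOption is an assumption; check where it lives in your installed bionemo-framework.

from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption  # assumed path

cfg_te = ESM2Config()  # default spec: esm2_bert_layer_with_transformer_engine_spec
assert cfg_te.apply_query_key_layer_scaling is False

cfg_local = ESM2Config(biobert_spec_option=BiobertSpecOption.esm2_bert_layer_local_spec)
assert cfg_local.apply_query_key_layer_scaling is True  # emits a deprecation warning
# Any other spec option raises ValueError("Unknown biobert_spec_option: ...").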

ESM2Model

Bases: MegatronBioBertModel

ESM2 Transformer language model.

Source code in bionemo/esm2/model/model.py, lines 52-221
class ESM2Model(MegatronBioBertModel):
    """ESM2 Transformer language model."""

    def __init__(
        self,
        config: TransformerConfig,
        num_tokentypes: int,
        transformer_layer_spec: spec_utils.ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        tokenizer: Optional[BioNeMoESMTokenizer] = None,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[float] = None,
        add_binary_head: bool = True,
        return_embeddings: bool = False,
        include_embeddings: bool = False,
        include_input_ids: bool = False,
        use_full_attention_mask: bool = False,
        include_hiddens: bool = False,
        skip_logits: bool = False,
    ) -> None:
        """Initialize the ESM2 model.

        Args:
            config (TransformerConfig): transformer config
            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
            vocab_size (int): vocabulary size
            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
            tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model)
            pre_process (bool): Include embedding layer (used with pipeline parallelism)
            post_process (bool): Include an output layer (used with pipeline parallelism)
            fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
                Defaults is 'learned_absolute'.
            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
            add_binary_head (bool): Whether to add a binary head. Defaults to True.
            return_embeddings (bool): Whether to return embeddings. Defaults to False.
            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
            include_input_ids (bool): Whether to include input_ids in the output dictionary. Defaults to False.
            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
            skip_logits (bool): Skip writing the token logits in output dict
        """
        super(MegatronBioBertModel, self).__init__(config=config)
        self.post_process = post_process
        self.add_binary_head = add_binary_head
        if return_embeddings:
            assert self.post_process, "only return embeddings on the last pipeline stage"
        # `b` = batch, `s` = sequence.
        # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
        #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
        self.use_full_attention_mask = use_full_attention_mask
        self.config: TransformerConfig = config
        self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type
        self.add_binary_head = add_binary_head
        self.return_embeddings = return_embeddings
        self.include_embeddings = include_embeddings
        self.include_hiddens = include_hiddens
        self.include_input_ids = include_input_ids
        self.skip_logits = skip_logits

        # megatron core pipelining currently depends on model type
        self.model_type = ModelType.encoder_or_decoder

        # Embeddings.
        if self.pre_process:
            self.register_buffer(
                "bert_position_id_tensor",
                torch.arange(max_sequence_length, dtype=torch.long, requires_grad=False).unsqueeze(0),
                persistent=False,
            )
            # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
            # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
            # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
            self.embedding = ESM2Embedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                num_tokentypes=num_tokentypes,
                # ESM2 NEW ARGS
                token_dropout=self.config.token_dropout,
                use_attention_mask=self.config.use_attention_mask,
                mask_token_id=tokenizer.mask_token_id,
            )

        if self.position_embedding_type == "rope":
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
            )

        # Transformer.
        self.encoder = TransformerBlock(
            config=self.config,
            spec=self.transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Output
        if post_process:
            # TODO: Make sure you are passing in the mpu_vocab_size properly
            self.lm_head = BertLMHead(
                config.hidden_size,
                config,
            )

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=True,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
            )

            self.binary_head = None
            if self.add_binary_head:
                # TODO: Should switch this to TE?
                self.binary_head = get_linear_layer(
                    config.hidden_size, 2, config.init_method, config.perform_initialization
                )

                self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

    def embedding_forward(
        self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
    ):
        """Forward pass of the embedding layer.

        Args:
            input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
            position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
            tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
            attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.

        Returns:
            Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
        """
        # ESM2 Customization: ESM2Embedding forward takes attention_mask
        # in addition to the args required by LanguageModelEmbedding
        return self.embedding(
            input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
        )

__init__(config, num_tokentypes, transformer_layer_spec, vocab_size, max_sequence_length, tokenizer=None, pre_process=True, post_process=True, fp16_lm_cross_entropy=False, parallel_output=True, share_embeddings_and_output_weights=False, position_embedding_type='learned_absolute', rotary_percent=1.0, seq_len_interpolation_factor=None, add_binary_head=True, return_embeddings=False, include_embeddings=False, include_input_ids=False, use_full_attention_mask=False, include_hiddens=False, skip_logits=False)

Initialize the ESM2 model.

Parameters

Name  Type  Description  Default
config  TransformerConfig  Transformer config.  required
num_tokentypes  int  Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.  required
transformer_layer_spec  ModuleSpec  Specifies the module to use for transformer layers.  required
vocab_size  int  Vocabulary size.  required
max_sequence_length  int  Maximum sequence length; used for positional embeddings.  required
tokenizer  AutoTokenizer  Optional tokenizer object (currently only used in the constructor of ESM2Model).  None
pre_process  bool  Include the embedding layer (used with pipeline parallelism).  True
post_process  bool  Include an output layer (used with pipeline parallelism).  True
fp16_lm_cross_entropy  bool  Whether to move the unreduced cross-entropy loss calculation for the LM head to fp16.  False
parallel_output  bool  Do not gather the outputs; keep them split across tensor parallel ranks.  True
share_embeddings_and_output_weights  bool  When True, input embeddings and output logit weights are shared.  False
position_embedding_type  string  Position embedding type. Options: ['learned_absolute', 'rope'].  'learned_absolute'
rotary_percent  float  Percent of the rotary dimension to use for rotary position embeddings. Ignored unless position_embedding_type is 'rope'.  1.0
seq_len_interpolation_factor  Optional[float]  Interpolation factor for sequence length.  None
add_binary_head  bool  Whether to add a binary head.  True
return_embeddings  bool  Whether to return embeddings.  False
include_embeddings  bool  Whether to include embeddings in the output dictionary.  False
include_input_ids  bool  Whether to include input_ids in the output dictionary.  False
use_full_attention_mask  bool  Whether to use a full attention mask.  False
include_hiddens  bool  Whether to include hidden states in the output dictionary.  False
skip_logits  bool  Skip writing the token logits to the output dictionary.  False
Source code in bionemo/esm2/model/model.py, lines 55-201
def __init__(
    self,
    config: TransformerConfig,
    num_tokentypes: int,
    transformer_layer_spec: spec_utils.ModuleSpec,
    vocab_size: int,
    max_sequence_length: int,
    tokenizer: Optional[BioNeMoESMTokenizer] = None,
    pre_process: bool = True,
    post_process: bool = True,
    fp16_lm_cross_entropy: bool = False,
    parallel_output: bool = True,
    share_embeddings_and_output_weights: bool = False,
    position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
    rotary_percent: float = 1.0,
    seq_len_interpolation_factor: Optional[float] = None,
    add_binary_head: bool = True,
    return_embeddings: bool = False,
    include_embeddings: bool = False,
    include_input_ids: bool = False,
    use_full_attention_mask: bool = False,
    include_hiddens: bool = False,
    skip_logits: bool = False,
) -> None:
    """Initialize the ESM2 model.

    Args:
        config (TransformerConfig): transformer config
        num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
        vocab_size (int): vocabulary size
        max_sequence_length (int): maximum size of sequence. This is used for positional embedding
        tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model)
        pre_process (bool): Include embedding layer (used with pipeline parallelism)
        post_process (bool): Include an output layer (used with pipeline parallelism)
        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
        share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
            Defaults is 'learned_absolute'.
        rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
            Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
        seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
        add_binary_head (bool): Whether to add a binary head. Defaults to True.
        return_embeddings (bool): Whether to return embeddings. Defaults to False.
        include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
        include_input_ids (bool): Whether to include input_ids in the output dictionary. Defaults to False.
        use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
        include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
        skip_logits (bool): Skip writing the token logits in output dict
    """
    super(MegatronBioBertModel, self).__init__(config=config)
    self.post_process = post_process
    self.add_binary_head = add_binary_head
    if return_embeddings:
        assert self.post_process, "only return embeddings on the last pipeline stage"
    # `b` = batch, `s` = sequence.
    # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
    #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
    self.use_full_attention_mask = use_full_attention_mask
    self.config: TransformerConfig = config
    self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
    self.vocab_size = vocab_size
    self.max_sequence_length = max_sequence_length
    self.pre_process = pre_process
    self.post_process = post_process
    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
    self.parallel_output = parallel_output
    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
    self.position_embedding_type = position_embedding_type
    self.add_binary_head = add_binary_head
    self.return_embeddings = return_embeddings
    self.include_embeddings = include_embeddings
    self.include_hiddens = include_hiddens
    self.include_input_ids = include_input_ids
    self.skip_logits = skip_logits

    # megatron core pipelining currently depends on model type
    self.model_type = ModelType.encoder_or_decoder

    # Embeddings.
    if self.pre_process:
        self.register_buffer(
            "bert_position_id_tensor",
            torch.arange(max_sequence_length, dtype=torch.long, requires_grad=False).unsqueeze(0),
            persistent=False,
        )
        # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
        # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
        # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
        self.embedding = ESM2Embedding(
            config=self.config,
            vocab_size=self.vocab_size,
            max_sequence_length=self.max_sequence_length,
            position_embedding_type=position_embedding_type,
            num_tokentypes=num_tokentypes,
            # ESM2 NEW ARGS
            token_dropout=self.config.token_dropout,
            use_attention_mask=self.config.use_attention_mask,
            mask_token_id=tokenizer.mask_token_id,
        )

    if self.position_embedding_type == "rope":
        self.rotary_pos_emb = RotaryEmbedding(
            kv_channels=self.config.kv_channels,
            rotary_percent=rotary_percent,
            rotary_interleaved=self.config.rotary_interleaved,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
        )

    # Transformer.
    self.encoder = TransformerBlock(
        config=self.config,
        spec=self.transformer_layer_spec,
        pre_process=self.pre_process,
        post_process=self.post_process,
    )

    # Output
    if post_process:
        # TODO: Make sure you are passing in the mpu_vocab_size properly
        self.lm_head = BertLMHead(
            config.hidden_size,
            config,
        )

        self.output_layer = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            self.vocab_size,
            config=config,
            init_method=config.init_method,
            bias=True,
            skip_bias_add=False,
            gather_output=not self.parallel_output,
            skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
        )

        self.binary_head = None
        if self.add_binary_head:
            # TODO: Should switch this to TE?
            self.binary_head = get_linear_layer(
                config.hidden_size, 2, config.init_method, config.perform_initialization
            )

            self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
    if self.pre_process or self.post_process:
        self.setup_embeddings_and_output_layer()
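
In practice, ESM2Model instances are usually produced from a config rather than by calling __init__ directly. The sketch below follows that pattern but is hedged: the configure_model call and the get_tokenizer helper location should be checked against your installed bionemo-framework, and an initialized Megatron environment (model parallelism, GPU) is assumed.

from bionemo.esm2.data.tokenizer import get_tokenizer  # assumed helper location
from bionemo.esm2.model.model import ESM2Config

tokenizer = get_tokenizer()
config = ESM2Config(seq_length=1024)
# configure_model builds an ESM2Model with the constructor arguments documented above.
model = config.configure_model(tokenizer)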

embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)

Forward pass of the embedding layer.

Parameters

Name  Type  Description  Default
input_ids  Tensor  Input tensor of shape (batch_size, sequence_length) containing the input IDs.  required
position_ids  Tensor  Tensor of shape (batch_size, sequence_length) containing the position IDs.  required
tokentype_ids  Tensor  Tensor of shape (batch_size, sequence_length) containing the token type IDs.  None
attention_mask  Tensor  Tensor of shape (batch_size, sequence_length) containing the attention mask.  None

Returns

Type  Description
Tensor  Output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.

Source code in bionemo/esm2/model/model.py, lines 203-221
def embedding_forward(
    self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
):
    """Forward pass of the embedding layer.

    Args:
        input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
        position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
        tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
        attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.

    Returns:
        Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
    """
    # ESM2 Customization: ESM2Embedding forward takes attention_mask
    # in addition to the args required by LanguageModelEmbedding
    return self.embedding(
        input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
    )
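
A shape-level sketch of calling embedding_forward, reusing the model and tokenizer from the construction sketch above; real batches would come from the ESM2 data pipeline, so the random IDs and all-ones mask here are placeholders.

import torch

batch_size, seq_len = 2, 128
input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_len))
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

embeddings = model.embedding_forward(
    input_ids=input_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
)
# Per the docstring above, the result has shape (batch_size, sequence_length, hidden_size).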