
Config Models

ESM2DataConfig

Bases: DataConfig[ESMDataModule]

ESM2DataConfig is a configuration class for setting up the pre-training data module for ESM2.

The ESM2DataModule implements the cluster-oriented sampling method defined in the ESM2 publication.

Attributes

train_cluster_path (Path): Path to the training cluster data.
train_database_path (Path): Path to the training database.
valid_cluster_path (Path): Path to the validation cluster data.
valid_database_path (Path): Path to the validation database.
micro_batch_size (int): Size of the micro-batch. Default is 8.
result_dir (str): Directory to store results. Default is "./results".
min_seq_length (int): Minimum sequence length. Default is 128.
max_seq_length (int): Maximum sequence length. Default is 128.
random_mask_strategy (RandomMaskStrategy): Strategy for random masking. Default is RandomMaskStrategy.ALL_TOKENS.
num_dataset_workers (int): Number of workers for the dataset. Default is 0.

Methods

construct_data_module(global_batch_size: int) -> ESMDataModule: Constructs and returns an ESMDataModule instance with the provided global batch size.

Source code in bionemo/esm2/run/config_models.py
class ESM2DataConfig(DataConfig[ESMDataModule]):
    """ESM2DataConfig is a configuration class for setting up the pre-training data module for ESM2.

    The ESM2DataModule implements the cluster oriented sampling method defined in the ESM2 publication.

    Attributes:
        train_cluster_path (Path): Path to the training cluster data.
        train_database_path (Path): Path to the training database.
        valid_cluster_path (Path): Path to the validation cluster data.
        valid_database_path (Path): Path to the validation database.
        micro_batch_size (int): Size of the micro-batch. Default is 8.
        result_dir (str): Directory to store results. Default is "./results".
        min_seq_length (int): Minimum sequence length. Default is 128.
        max_seq_length (int): Maximum sequence length. Default is 128.
        random_mask_strategy (RandomMaskStrategy): Strategy for random masking. Default is RandomMaskStrategy.ALL_TOKENS.
        num_dataset_workers (int): Number of workers for the dataset. Default is 0.

    Methods:
        construct_data_module(global_batch_size: int) -> ESMDataModule:
            Constructs and returns an ESMDataModule instance with the provided global batch size.
    """

    train_cluster_path: Path
    train_database_path: Path
    valid_cluster_path: Path
    valid_database_path: Path

    micro_batch_size: int = 8
    result_dir: str | Path = "./results"
    min_seq_length: int = 128
    max_seq_length: int = 128
    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS
    num_dataset_workers: int = 0

    @field_serializer(
        "train_cluster_path", "train_database_path", "valid_cluster_path", "valid_database_path", "result_dir"
    )
    def serialize_paths(self, value: Path) -> str:  # noqa: D102
        return serialize_path_or_str(value)

    @field_validator(
        "train_cluster_path", "train_database_path", "valid_cluster_path", "valid_database_path", "result_dir"
    )
    def deserialize_paths(cls, value: str) -> Path:  # noqa: D102
        return deserialize_str_to_path(value)

    @field_serializer("random_mask_strategy")
    def serialize_spec_option(self, value: RandomMaskStrategy) -> str:  # noqa: D102
        return value.value

    @field_validator("random_mask_strategy", mode="before")
    def deserialize_spec_option(cls, value: str) -> RandomMaskStrategy:  # noqa: D102
        return RandomMaskStrategy(value)

    def construct_data_module(self, global_batch_size: int) -> ESMDataModule:
        """Constructs and returns an ESMDataModule instance with the provided global batch size.

        This method provides the means for constructing the datamodule; any pre-requisites for the DataModule should be
        acquired here. For example, tokenizers and preprocessing may want to live in this method.

        Args:
            global_batch_size (int): Global batch size for the data module. Global batch size must be a function of
                parallelism settings and the `micro_batch_size` attribute. Since the DataConfig has no ownership over
                parallelism configuration, we expect someone higher up on the ownership chain to provide the value to
                this method.

        """
        tokenizer = get_tokenizer()
        data = ESMDataModule(
            train_cluster_path=self.train_cluster_path,
            train_database_path=self.train_database_path,
            valid_cluster_path=self.valid_cluster_path,
            valid_database_path=self.valid_database_path,
            global_batch_size=global_batch_size,
            micro_batch_size=self.micro_batch_size,
            min_seq_length=self.min_seq_length,
            max_seq_length=self.max_seq_length,
            num_workers=self.num_dataset_workers,
            random_mask_strategy=self.random_mask_strategy,
            tokenizer=tokenizer,
        )
        return data
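
As a quick orientation, the sketch below constructs an ESM2DataConfig and serializes it back to JSON. It assumes the bionemo-esm2 package is importable, and the file paths are hypothetical placeholders for real preprocessed cluster and database files.

from pathlib import Path

from bionemo.esm2.run.config_models import ESM2DataConfig

# Hypothetical data locations; substitute real preprocessed ESM2 cluster/database files.
data_config = ESM2DataConfig(
    train_cluster_path=Path("/data/esm2/train_clusters.parquet"),
    train_database_path=Path("/data/esm2/train.db"),
    valid_cluster_path=Path("/data/esm2/valid_clusters.parquet"),
    valid_database_path=Path("/data/esm2/valid.db"),
    micro_batch_size=8,
    min_seq_length=1024,
    max_seq_length=1024,
    num_dataset_workers=4,
)

# The path field serializers defined above turn Path fields back into plain strings on dump.
print(data_config.model_dump_json(indent=2))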

construct_data_module(global_batch_size)

Constructs and returns an ESMDataModule instance with the provided global batch size.

This method provides the means for constructing the data module; any prerequisites for the DataModule should be acquired here. For example, tokenizers and preprocessing may want to live in this method.

Parameters

global_batch_size (int): Global batch size for the data module. The global batch size must be a function of the parallelism settings and the micro_batch_size attribute. Since the DataConfig has no ownership over the parallelism configuration, a component higher up the ownership chain is expected to provide the value to this method. Required.
Source code in bionemo/esm2/run/config_models.py
def construct_data_module(self, global_batch_size: int) -> ESMDataModule:
    """Constructs and returns an ESMDataModule instance with the provided global batch size.

    This method provides the means for constructing the datamodule; any pre-requisites for the DataModule should be
    acquired here. For example, tokenizers and preprocessing may want to live in this method.

    Args:
        global_batch_size (int): Global batch size for the data module. Global batch size must be a function of
            parallelism settings and the `micro_batch_size` attribute. Since the DataConfig has no ownership over
            parallelism configuration, we expect someone higher up on the ownership chain to provide the value to
            this method.

    """
    tokenizer = get_tokenizer()
    data = ESMDataModule(
        train_cluster_path=self.train_cluster_path,
        train_database_path=self.train_database_path,
        valid_cluster_path=self.valid_cluster_path,
        valid_database_path=self.valid_database_path,
        global_batch_size=global_batch_size,
        micro_batch_size=self.micro_batch_size,
        min_seq_length=self.min_seq_length,
        max_seq_length=self.max_seq_length,
        num_workers=self.num_dataset_workers,
        random_mask_strategy=self.random_mask_strategy,
        tokenizer=tokenizer,
    )
    return data
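
Continuing the data_config sketch above, the global batch size handed to construct_data_module is typically derived from the parallel configuration; the arithmetic below uses hypothetical parallelism values.

# Hypothetical parallelism settings; in practice these come from the parallel/training config.
data_parallel_size = 8        # e.g. world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
grad_accumulation_steps = 2

global_batch_size = data_config.micro_batch_size * data_parallel_size * grad_accumulation_steps  # 8 * 8 * 2 = 128
data_module = data_config.construct_data_module(global_batch_size)  # ESMDataModule ready for the trainer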

ExposedESM2PretrainConfig

Bases: ExposedModelConfig[ESM2Config]

Configuration class for ESM2 pretraining with select exposed parameters.

See the inherited ExposedModelConfig for attributes and methods from the base class. Use this class either as a template or an extension for custom configurations. Importantly, classes of this kind should do two things: select the attributes to expose to the user, and provide validation and serialization for any of those attributes.

Attributes

use_esm_attention (bool): Flag to skip ESM2 custom attention for TE acceleration. Defaults to False.
token_dropout (bool): Flag to enable token dropout. Defaults to True.
normalize_attention_scores (bool): Flag to normalize attention scores. Defaults to False.
variable_seq_lengths (bool): Flag to enable variable sequence lengths. Defaults to False.
core_attention_override (Optional[Type[torch.nn.Module]]): Optional override for the core attention module. Defaults to None.

Methods

restrict_biobert_spec_to_esm2(biobert_spec_option: BiobertSpecOption) -> BiobertSpecOption: Validates the BiobertSpecOption to ensure it is compatible with ESM2.
serialize_core_attention_override(value: Optional[Type[torch.nn.Module]]) -> Optional[str]: Serializes the core attention override module to a string.
validate_core_attention_override(value): Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.
validate_and_set_attention_and_scaling(): Validates and sets the attention and scaling parameters based on the biobert_spec_option.
model_validator(global_cfg: MainConfig) -> MainConfig: Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
model_class() -> Type[ESM2Config]: Returns the model class associated with this configuration.

Source code in bionemo/esm2/run/config_models.py
class ExposedESM2PretrainConfig(ExposedModelConfig[ESM2Config]):
    """Configuration class for ESM2 pretraining with select exposed parameters.

    See the inherited ExposedModelConfig for attributes and methods from the base class. Use this class either
    as a template or extension for custom configurations. Importantly, these kinds of classes should do two things:
    select attributes to expose to the user, and provide validation and serialization for any attributes.

    Attributes:
        use_esm_attention (bool): Flag to skip ESM2 custom attention for TE acceleration. Defaults to False.
        token_dropout (bool): Flag to enable token dropout. Defaults to True.
        normalize_attention_scores (bool): Flag to normalize attention scores. Defaults to False.
        variable_seq_lengths (bool): Flag to enable variable sequence lengths. Defaults to False.
        core_attention_override (Optional[Type[torch.nn.Module]]): Optional override for core attention module. Defaults to None.

    Methods:
        restrict_biobert_spec_to_esm2(cls, biobert_spec_option: BiobertSpecOption) -> BiobertSpecOption:
            Validates the BiobertSpecOption to ensure it is compatible with ESM2.
        serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]:
            Serializes the core attention override module to a string.
        validate_core_attention_override(cls, value):
            Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.
        validate_and_set_attention_and_scaling(self):
            Validates and sets the attention and scaling parameters based on the biobert_spec_option.
        model_validator(self, global_cfg: MainConfig) -> MainConfig:
            Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
        model_class(self) -> Type[ESM2Config]:
            Returns the model class associated with this configuration.
    """

    use_esm_attention: bool = False  # Skip ESM2 custom attention for TE acceleration. Still passes golden value test.
    token_dropout: bool = True
    normalize_attention_scores: bool = False
    variable_seq_lengths: bool = False
    core_attention_override: Type[torch.nn.Module] | None = None

    @field_serializer("core_attention_override")
    def serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]:
        """Serializes the core attention override module to a string."""
        if value is None:
            return None
        return f"{value.__module__}.{value.__name__}"

    @field_validator("core_attention_override", mode="before")
    def validate_core_attention_override(cls, value):
        """Validates the core attention override module, ensuring it is a subclass of torch.nn.Module."""
        if value is None:
            return None
        if isinstance(value, str):
            module_name, class_name = value.rsplit(".", 1)
            try:
                module = importlib.import_module(module_name)
                cls = getattr(module, class_name)
                if not issubclass(cls, torch.nn.Module):
                    raise ValueError(f"{cls} is not a subclass of torch.nn.Module")
                return cls
            except (ImportError, AttributeError):
                raise ValueError(f"Cannot import {value}")
        return value

    @model_validator(mode="after")
    def validate_and_set_attention_and_scaling(self):
        """Validates and sets the attention and scaling parameters based on the biobert_spec_option."""
        logging.info(
            "Mutating apply_query_key_layer_scaling and core_attention_override based on biobert_spec_option.."
        )
        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
            self.apply_query_key_layer_scaling = False
        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
            logging.warning(
                "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. "
                "Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
            )
            self.apply_query_key_layer_scaling = True
        return self

    def model_validator(self, global_cfg: MainConfig) -> MainConfig:
        """Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.

        The global validator acts on the MainConfig, this couples together the ESM2DataConfig with ESM2PretrainingConfig.
        Additionally, it provides validation for sequence length and parallelism settings.

        Args:
            global_cfg (MainConfig): The global configuration object.
        """
        global_cfg = super().model_validator(global_cfg)
        # Need to ensure that at the least we have access to min_seq_length and max_seq_length
        if not isinstance(global_cfg.data_config, ESM2DataConfig):
            raise TypeError(f"ESM2PretrainConfig requires ESM2DataConfig, got {global_cfg.data_config=}")

        pipeline_model_parallel_size, tensor_model_parallel_size = (
            global_cfg.parallel_config.pipeline_model_parallel_size,
            global_cfg.parallel_config.tensor_model_parallel_size,
        )
        min_seq_length, max_seq_length = global_cfg.data_config.min_seq_length, global_cfg.data_config.max_seq_length
        assert (
            self.variable_seq_lengths
            == (pipeline_model_parallel_size * tensor_model_parallel_size > 1 and min_seq_length != max_seq_length)
        ), "Must set variable_seq_lengths to True when min_seq_length != max_seq_length under pipeline or tensor parallelism."
        return global_cfg

    def model_class(self) -> Type[ESM2Config]:
        """Returns the model class associated with this configuration."""
        return ESM2Config
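
As the class docstring suggests, this config can also serve as a template for custom configurations; a minimal sketch of such an extension (with a purely illustrative override) might look like this.

from bionemo.esm2.run.config_models import ExposedESM2PretrainConfig

class MyESM2PretrainConfig(ExposedESM2PretrainConfig):
    """Hypothetical extension that overrides an exposed default for a custom run."""

    token_dropout: bool = False  # disable token dropout for this experiment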

model_class()

Returns the model class associated with this configuration.

Source code in bionemo/esm2/run/config_models.py
def model_class(self) -> Type[ESM2Config]:
    """Returns the model class associated with this configuration."""
    return ESM2Config

model_validator(global_cfg)

Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.

The global validator acts on the MainConfig; this couples the ESM2DataConfig together with the ESM2PretrainingConfig. Additionally, it provides validation for sequence length and parallelism settings.

Parameters

global_cfg (MainConfig): The global configuration object. Required.
Source code in bionemo/esm2/run/config_models.py
def model_validator(self, global_cfg: MainConfig) -> MainConfig:
    """Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.

    The global validator acts on the MainConfig, this couples together the ESM2DataConfig with ESM2PretrainingConfig.
    Additionally, it provides validation for sequence length and parallelism settings.

    Args:
        global_cfg (MainConfig): The global configuration object.
    """
    global_cfg = super().model_validator(global_cfg)
    # Need to ensure that at the least we have access to min_seq_length and max_seq_length
    if not isinstance(global_cfg.data_config, ESM2DataConfig):
        raise TypeError(f"ESM2PretrainConfig requires ESM2DataConfig, got {global_cfg.data_config=}")

    pipeline_model_parallel_size, tensor_model_parallel_size = (
        global_cfg.parallel_config.pipeline_model_parallel_size,
        global_cfg.parallel_config.tensor_model_parallel_size,
    )
    min_seq_length, max_seq_length = global_cfg.data_config.min_seq_length, global_cfg.data_config.max_seq_length
    assert (
        self.variable_seq_lengths
        == (pipeline_model_parallel_size * tensor_model_parallel_size > 1 and min_seq_length != max_seq_length)
    ), "Must set variable_seq_lengths to True when min_seq_length != max_seq_length under pipeline or tensor parallelism."
    return global_cfg
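
To make the assertion above concrete, here is a small illustration, with hypothetical parallelism and sequence-length settings, of the value variable_seq_lengths is expected to take.

# Hypothetical settings: 2-way tensor parallelism, no pipeline parallelism,
# and a data config that allows a range of sequence lengths.
pipeline_model_parallel_size, tensor_model_parallel_size = 1, 2
min_seq_length, max_seq_length = 128, 1024

# The validator requires variable_seq_lengths to equal exactly this expression:
expected_variable_seq_lengths = (
    pipeline_model_parallel_size * tensor_model_parallel_size > 1
    and min_seq_length != max_seq_length
)  # True here; False when min_seq_length == max_seq_length or when no model parallelism is used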

serialize_core_attention_override(value)

Serializes the core attention override module to a string.

Source code in bionemo/esm2/run/config_models.py
@field_serializer("core_attention_override")
def serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]:
    """Serializes the core attention override module to a string."""
    if value is None:
        return None
    return f"{value.__module__}.{value.__name__}"

validate_and_set_attention_and_scaling()

Validates and sets the attention and scaling parameters based on the biobert_spec_option.

Source code in bionemo/esm2/run/config_models.py
@model_validator(mode="after")
def validate_and_set_attention_and_scaling(self):
    """Validates and sets the attention and scaling parameters based on the biobert_spec_option."""
    logging.info(
        "Mutating apply_query_key_layer_scaling and core_attention_override based on biobert_spec_option.."
    )
    if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
        self.apply_query_key_layer_scaling = False
    elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
        logging.warning(
            "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. "
            "Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
        )
        self.apply_query_key_layer_scaling = True
    return self

validate_core_attention_override(value)

Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.

Source code in bionemo/esm2/run/config_models.py
@field_validator("core_attention_override", mode="before")
def validate_core_attention_override(cls, value):
    """Validates the core attention override module, ensuring it is a subclass of torch.nn.Module."""
    if value is None:
        return None
    if isinstance(value, str):
        module_name, class_name = value.rsplit(".", 1)
        try:
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            if not issubclass(cls, torch.nn.Module):
                raise ValueError(f"{cls} is not a subclass of torch.nn.Module")
            return cls
        except (ImportError, AttributeError):
            raise ValueError(f"Cannot import {value}")
    return value
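
The validator resolves dotted import paths with importlib. The standalone helper below is hypothetical (not part of the bionemo API) but mirrors the same pattern, which also makes the round-trip with serialize_core_attention_override easy to see.

import importlib

import torch

def resolve_core_attention_class(dotted_path: str) -> type:
    """Resolve a dotted path such as "torch.nn.MultiheadAttention" into a torch.nn.Module subclass."""
    module_name, class_name = dotted_path.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    if not issubclass(cls, torch.nn.Module):
        raise ValueError(f"{cls} is not a subclass of torch.nn.Module")
    return cls

attention_cls = resolve_core_attention_class("torch.nn.MultiheadAttention")

# serialize_core_attention_override writes the class back out as a dotted path built from its
# defining module (e.g. "torch.nn.modules.activation.MultiheadAttention"), which resolves to
# the same class on the next load.
serialized = f"{attention_cls.__module__}.{attention_cls.__name__}"
assert resolve_core_attention_class(serialized) is attention_cls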