跳到内容

Infer esm2

get_parser()

返回此工具的 cli 解析器。

源代码位于 bionemo/esm2/scripts/infer_esm2.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def get_parser() -> argparse.ArgumentParser:
    """Return the cli parser for this tool.

    The parser exposes three required paths (checkpoint, input data, results directory),
    precision / device / parallelism settings, the interval at which predictions are
    written, flags selecting which outputs to include, and the model config class.

    Returns:
        The configured ``argparse.ArgumentParser``. ``--config-class`` is converted from its
        string key into the matching entry of ``SUPPORTED_CONFIGS``.
    """
    parser = argparse.ArgumentParser(description="Infer ESM2.")
    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        required=True,
        help="Path to the ESM2 pretrained checkpoint",
    )
    parser.add_argument(
        "--data-path",
        type=Path,
        required=True,
        help="Path to the CSV file containing sequences and label columns",
    )
    parser.add_argument("--results-path", type=Path, required=True, help="Path to the results directory.")

    parser.add_argument(
        "--precision",
        type=str,
        choices=get_args(PrecisionTypes),
        required=False,
        default="bf16-mixed",
        help="Precision type to use for inference.",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        required=False,
        default=1,
        help="Number of GPUs to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        required=False,
        default=1,
        help="Number of nodes to use for inference. Default is 1.",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        required=False,
        default=2,
        help="Micro-batch size. Global batch size is inferred from this.",
    )
    parser.add_argument(
        "--pipeline-model-parallel-size",
        type=int,
        required=False,
        default=1,
        help="Pipeline model parallel size. Default is 1.",
    )
    parser.add_argument(
        "--tensor-model-parallel-size",
        type=int,
        required=False,
        default=1,
        help="Tensor model parallel size. Default is 1.",
    )
    parser.add_argument(
        "--prediction-interval",
        type=str,
        required=False,
        choices=get_args(IntervalT),
        default="epoch",
        help="Intervals to write DDP predictions into disk",
    )
    parser.add_argument(
        "--include-hiddens",
        action="store_true",
        default=False,
        help="Include hiddens in output of inference",
    )
    parser.add_argument(
        "--include-input-ids",
        action="store_true",
        default=False,
        help="Include input_ids in output of inference",
    )
    parser.add_argument(
        "--include-embeddings",
        action="store_true",
        default=False,
        help="Include embeddings in output of inference",
    )
    parser.add_argument(
        "--include-logits", action="store_true", default=False, help="Include per-token logits in output."
    )
    config_class_options: Dict[str, Type[BioBertConfig]] = SUPPORTED_CONFIGS
    # Render the valid keys as a plain comma-separated list; interpolating ``.keys()`` directly
    # would leak the ``dict_keys([...])`` repr into user-facing help and error text.
    valid_config_names = ", ".join(config_class_options)

    def config_class_type(desc: str) -> Type[BioBertConfig]:
        """Map a config-class key given on the command line to the config class itself."""
        try:
            return config_class_options[desc]
        except KeyError as err:
            # ArgumentTypeError is what argparse expects from a ``type=`` callable in order to
            # report a clean usage error instead of a traceback.
            raise argparse.ArgumentTypeError(
                f"Do not recognize key {desc}, valid options are: {valid_config_names}"
            ) from err

    parser.add_argument(
        "--config-class",
        type=config_class_type,
        # A string default is run through ``type`` by argparse, so this resolves to the class.
        default="ESM2Config",
        help="Model configs link model classes with losses, and handle model initialization (including from a prior "
        "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
        "class and loss, then implement and provide an alternative config class that points to a variant of that model "
        "and alternative loss. In the future this script should also provide similar support for picking different data "
        f"modules for fine-tuning with different data types. Choices: {valid_config_names}",
    )
    return parser

infer_esm2_entrypoint()

在 ESM2 检查点和数据上运行推理的入口点。

源代码位于 bionemo/esm2/scripts/infer_esm2.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def infer_esm2_entrypoint():
    """Entrypoint for running inference on an ESM2 checkpoint and data.

    Parses the command line with ``get_parser()`` and forwards every parsed option to
    ``infer_model``. ``min_seq_length`` is not exposed on the command line, so
    ``infer_model`` uses its own default for it.
    """
    # 1. get arguments
    parser = get_parser()
    args = parser.parse_args()
    # 2. Call infer with args
    infer_model(
        data_path=args.data_path,
        checkpoint_path=args.checkpoint_path,
        results_path=args.results_path,
        include_hiddens=args.include_hiddens,
        include_embeddings=args.include_embeddings,
        include_logits=args.include_logits,
        include_input_ids=args.include_input_ids,
        micro_batch_size=args.micro_batch_size,
        precision=args.precision,
        tensor_model_parallel_size=args.tensor_model_parallel_size,
        pipeline_model_parallel_size=args.pipeline_model_parallel_size,
        devices=args.num_gpus,
        num_nodes=args.num_nodes,
        # Forward --prediction-interval; without this the flag is parsed but has no effect.
        prediction_interval=args.prediction_interval,
        config_class=args.config_class,
    )

infer_model(data_path, checkpoint_path, results_path, min_seq_length=1024, include_hiddens=False, include_embeddings=False, include_logits=False, include_input_ids=False, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, prediction_interval='epoch', config_class=ESM2Config)

使用 PyTorch Lightning 在 BioNeMo ESM2 模型上运行推理。

参数

名称 类型 描述 默认值
data_path Path

输入数据的路径。

必需
checkpoint_path Path

模型检查点的路径。

必需
results_path Path

保存推理结果的路径。

必需
min_seq_length int

要填充到的最小序列长度。该值应至少等于数据集中最长序列的长度。

1024
include_hiddens bool

是否在输出中包含隐藏状态。默认为 False。

False
include_embeddings bool

是否在输出中包含嵌入。默认为 False。

False
include_logits (bool, 可选)

是否在输出中包含令牌 logits。默认为 False。

False
include_input_ids (bool, 可选)

是否在输出中包含 input_ids。默认为 False。

False
micro_batch_size int

用于推理的微批大小。默认为 64。

64
precision PrecisionTypes

用于推理的精度类型。默认为 "bf16-mixed"。

'bf16-mixed'
tensor_model_parallel_size int

用于分布式推理的张量模型并行大小。默认为 1。

1
pipeline_model_parallel_size int

用于分布式推理的流水线模型并行大小。默认为 1。

1
devices int

用于推理的设备数量。默认为 1。

1
num_nodes int

用于分布式推理的节点数量。默认为 1。

1
prediction_interval IntervalT

将 predict 方法输出写入磁盘以进行 DDP 推理的间隔。默认为 epoch。

'epoch'
config_class Type[BioBertConfig]

用于使用提供的检查点配置模型的配置类

ESM2Config
源代码位于 bionemo/esm2/scripts/infer_esm2.py
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def infer_model(
    data_path: Path,
    checkpoint_path: Path,
    results_path: Path,
    min_seq_length: int = 1024,
    include_hiddens: bool = False,
    include_embeddings: bool = False,
    include_logits: bool = False,
    include_input_ids: bool = False,
    micro_batch_size: int = 64,
    precision: PrecisionTypes = "bf16-mixed",
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    devices: int = 1,
    num_nodes: int = 1,
    prediction_interval: IntervalT = "epoch",
    config_class: Type[BioBertConfig] = ESM2Config,
) -> None:
    """Run prediction with a BioNeMo ESM2 model and write the outputs to ``results_path``.

    The function creates the results directory, builds a Megatron strategy and a GPU trainer
    with a ``PredictionWriter`` callback, loads the CSV at ``data_path`` into an in-memory
    dataset, builds the model config from ``checkpoint_path`` and calls ``trainer.predict``.

    Args:
        data_path: Path to the input data (read by ``InMemoryCSVDataset``).
        checkpoint_path: Path to the model checkpoint; passed to the config as ``initial_ckpt_path``.
        results_path: Directory the prediction writer saves into. Created if it does not exist.
        min_seq_length: Minimum sequence length to pad to. This should be at least equal to the
            length of the longest sequence in the dataset.
        include_hiddens: Whether to include hidden states in the output.
        include_embeddings: Whether to include embeddings in the output.
        include_logits: Whether to include token logits in the output. Passed to the config
            inverted, as ``skip_logits``.
        include_input_ids: Whether to include input_ids in the output.
        micro_batch_size: Micro batch size for inference; the global batch size is derived from it.
        precision: Precision type for inference.
        tensor_model_parallel_size: Tensor model parallel size for distributed inference.
        pipeline_model_parallel_size: Pipeline model parallel size for distributed inference.
        devices: Number of devices to use for inference.
        num_nodes: Number of nodes to use for distributed inference.
        prediction_interval: Interval at which predict output is written to disk for DDP inference.
        config_class: The config class used to configure the model from the provided checkpoint.
    """
    # The writer callback needs an existing directory to save into.
    os.makedirs(results_path, exist_ok=True)

    # Both parallel sizes are needed by the batch-size helper, the strategy and the config.
    parallel_sizes = {
        "tensor_model_parallel_size": tensor_model_parallel_size,
        "pipeline_model_parallel_size": pipeline_model_parallel_size,
    }

    # Strategy and trainer setup.
    global_batch_size = infer_global_batch_size(
        micro_batch_size=micro_batch_size,
        num_nodes=num_nodes,
        devices=devices,
        **parallel_sizes,
    )
    strategy = nl.MegatronStrategy(
        ddp="megatron",
        find_unused_parameters=True,
        **parallel_sizes,
    )
    writer_callback = PredictionWriter(output_dir=results_path, write_interval=prediction_interval)
    trainer = nl.Trainer(
        accelerator="gpu",
        devices=devices,
        strategy=strategy,
        num_nodes=num_nodes,
        callbacks=[writer_callback],
        plugins=nl.MegatronMixedPrecision(precision=precision),
    )

    # Data: the whole CSV is wrapped in a predict-only datamodule.
    datamodule = ESM2FineTuneDataModule(
        predict_dataset=InMemoryCSVDataset(data_path=data_path),
        micro_batch_size=micro_batch_size,
        global_batch_size=global_batch_size,
        min_seq_length=min_seq_length,
    )

    # Model config: the three dtype fields all follow the requested precision.
    dtype_fields = ("params_dtype", "pipeline_dtype", "autocast_dtype")
    dtype_kwargs = {field: get_autocast_dtype(precision) for field in dtype_fields}
    config = config_class(
        include_hiddens=include_hiddens,
        include_embeddings=include_embeddings,
        include_input_ids=include_input_ids,
        skip_logits=not include_logits,
        initial_ckpt_path=str(checkpoint_path),
        initial_ckpt_skip_keys_with_these_prefixes=[],  # load everything from the checkpoint.
        **dtype_kwargs,
        **parallel_sizes,
    )

    module = biobert_lightning_module(config=config, tokenizer=get_tokenizer())

    # datamodule is responsible for transforming dataloaders by adding MegatronDataSampler. Alternatively, to
    # directly use dataloader in predict method, the data sampler should be included in MegatronStrategy
    trainer.predict(module, datamodule=datamodule)  # return_predictions=False failing due to a lightning bug