使用 ESM-2 进行零样本蛋白质设计¶
我们感谢 A-Alpha Bio 的 Adrian Lange 最初贡献了这个配方。 此笔记本自那时起已由 NVIDIA 修改。
演示目标¶
- ESM-2nv 推理功能
- 目标:对预训练的 ESM-2 模型执行推理。
- 步骤:下载模型检查点,创建蛋白质序列的 CSV 数据文件,并从输入的蛋白质序列生成隐藏状态表示和序列嵌入。
- Logit 和概率提取
- 目标:获取氨基酸序列中每个位置所有可能 token 的概率值。
- 步骤:从隐藏状态生成 logits,并将它们转换为概率。
- 蛋白质突变体设计
- 目标:优化输入的蛋白质序列,使其更紧密地与自然产生的蛋白质变体对齐。
- 步骤:顺序掩码氨基酸,提取每个位置的概率(并创建热图),分析单点突变体比野生型具有更高可能性的位置,并开发新的候选物。
背景¶
ESM-2 是一个大规模蛋白质语言模型 (PLM),在数百万个蛋白质序列上训练。 它可以捕获蛋白质序列中复杂的模式和关系,使其可用于预测不同位置可能的氨基酸替换。 通过利用 ESM-2 的掩码语言建模 (MLM) 功能,我们可以识别可能增强蛋白质特性或使其更紧密地与自然产生的变体对齐的潜在突变。 ESM-2 有 650M 和 3B 参数版本 - 对于本演示,我们将使用 ESM-2 3B。
设置¶
此笔记本应在 BioNeMo Docker 容器内执行,该容器已预安装所有 ESM-2 依赖项。 本教程假设 BioNeMo 框架仓库的副本存在于工作站或服务器上,并且已挂载在容器内的 /workspace/bionemo2
。 有关如何构建或拉取 BioNeMo2 容器的更多信息,请参阅初始化指南。
%%capture --no-display --no-stderr cell_output来抑制此输出。 在下面的单元格中注释或删除此行以恢复完整输出。
导入所需的库¶
%%capture --no-display --no-stderr cell_output
import os
import torch
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
工作目录¶
设置工作目录以存储数据和结果
cleanup : bool = True
cleanup : bool = True
work_dir="/workspace/bionemo2/esm2_mutant_design_tutorial"
if cleanup and os.path.exists(work_dir):
shutil.rmtree(work_dir)
if not os.path.exists(work_dir):
os.makedirs(work_dir)
print(f"Directory '{work_dir}' created.")
else:
print(f"Directory '{work_dir}' already exists.")
Directory '/workspace/bionemo2/esm2_mutant_design_tutorial' created.
下载模型检查点¶
以下代码将从 NGC 注册表下载预训练模型
checkpoint = "esm2/3b:2.0"
from bionemo.core.data.load import load
checkpoint = "esm2/650m:2.0" # change to "esm2/3b:2.0" to use the ESM-2 3B model
checkpoint_path = load(checkpoint, source="ngc")
ESM-2 推理¶
在本节中,我们将探索预训练模型的关键推理功能。
数据¶
在第一步中,我们通过创建一个包含 sequences
列的 CSV 文件来准备数据,该列保存我们用作推理输入的蛋白质序列。
import pandas as pd
sequences = [
'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL', # length: 41
'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA', # length: 39
]
# Create a DataFrame
df = pd.DataFrame(sequences, columns=["sequences"])
# Save the DataFrame to a CSV file
data_path = os.path.join(work_dir, "sequences.csv")
df.to_csv(data_path, index=False)
分词器¶
让我们也检查一下分词器词汇表。
from bionemo.esm2.data.tokenizer import get_tokenizer, BioNeMoESMTokenizer
tokenizer = get_tokenizer()
tokens = tokenizer.all_tokens
print(f"There are {tokenizer.vocab_size} unique tokens: {tokens}.")
There are 33 unique tokens: ['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>'].
让我们把对应于 20 种已知氨基酸的 token 放在一边。
aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
aa_indices = [i for i, token in enumerate(tokens) if token in aa_tokens]
extra_indices = [i for i, token in enumerate(tokens) if token not in aa_tokens]
获取模型输出¶
ESM-2nv 是使用掩码语言建模 (MLM) 目标训练的。 因此,我们能够掩码氨基酸序列中的一个位置,并根据周围的上下文获得该位置最可能的氨基酸的值。 让我们依次获得序列中每个位置的这些值。
隐藏状态(通常是神经网络中每一层的输出)可以通过在使用 BioNeMo 框架中 ESM-2 的推理函数时使用 --include-hiddens
参数来获得。
隐藏状态可以转换为固定大小的向量嵌入。 这是通过删除与填充 token 对应的隐藏状态向量,然后对剩余部分进行平均来实现的。 当目标是从模型的隐藏状态创建单个向量表示时,通常使用此过程,该向量表示可用于各种序列级下游任务,例如分类(例如亚细胞定位)或回归(例如熔解温度预测)。 为了获得嵌入结果,我们可以使用 --include-embeddings
参数。
通过将氨基酸序列的隐藏状态传递到 BERT 语言模型头,我们可以获得每个位置的输出 logits,并将它们转换为概率。 这可以通过使用 --include-logits
参数来实现。 这里的 Logits 是原始的、未归一化的分数,表示每个类别的可能性,而不是概率本身;它们可以是任何实数,包括负值。
当我们对 logits 应用 softmax 函数时,它会将它们转换为类别的概率分布,其中概率之和等于 1。
现在,让我们使用相关参数调用 infer_esm2
可执行文件,以计算并可选地返回嵌入、隐藏状态和 logits。
%%capture --no-display --no-stderr cell_output
example_dir = os.path.join(work_dir, "inference_example")
os.makedirs(example_dir, exist_ok=True)
! infer_esm2 --checkpoint-path {checkpoint_path} \
--data-path {data_path} \
--results-path {example_dir} \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-hiddens \
--include-embeddings \
--include-logits \
--include-input-ids
这将把 ESM-2 推理的输出写入 python 字典,并将其保存到 predictions__rank_0.pt
中,可以通过 PyTorch 加载。 BioNeMo 框架中支持 DDP 推理,可以通过设置 --num-gpus n
来使用 n
个设备来利用它。 然后,输出预测将写入 n 个不同的文件 predictions__rank_<0...n-1>.pt
。 有关 DDP 支持以及如何解释预测输出的更多信息,请参阅ESM-2 推理教程。
results = torch.load(f"{example_dir}/predictions__rank_0.pt")
for key, val in results.items():
if val is not None:
print(f'{key}\t{val.shape}')
token_logits torch.Size([1024, 2, 128]) hidden_states torch.Size([2, 1024, 1280]) input_ids torch.Size([2, 1024]) embeddings torch.Size([2, 1280])
Logits (token_logits
) 张量的维度为 [sequence, batch, hidden]
,以提高训练性能。 我们将在下面转置前两个维度,使其具有像其余输出张量一样的批优先形状。
logits = results['token_logits'].transpose(0, 1) # s, b, h -> b, s, h
print(logits.shape)
torch.Size([2, 1024, 128])
toke_logits
的序列维度为 1024,其中包括序列开始、序列结束 (eos/bos) 和填充。 token_logits
的最后一个维度为 128,其中前 33 个位置对应于氨基酸词汇表,后跟 95 个填充。 我们使用 tokenizer.vocab_size
来过滤掉填充,并且仅保留 33 个词汇位置。
aa_logits = logits[..., :tokenizer.vocab_size] # filter out the 95 paddings and only keep 33 vocab positions
print(aa_logits.shape)
torch.Size([2, 1024, 33])
我们将通过在 -inf
上调用 softmax 来强制非氨基酸 token 的概率变为零。 这些 token ID 列为 extra_indices
,我们将 logits 值设置为 -inf
。
现在我们可以使用 PyTorch Softmax 函数将 logits 转换为概率。
aa_logits[..., extra_indices] = - torch.inf # force non-amino acid token probs to zero
probs = torch.softmax(aa_logits, dim=-1)
# check that rows sum to 1
# probs.sum(dim=-1)
这些步骤在下面的 logits_to_probs()
函数中进行了总结
def logits_to_probs(
logits: torch.Tensor, tokenizer: BioNeMoESMTokenizer = get_tokenizer()
) -> torch.Tensor:
"""Convert token logits to probabilities
Args:
logits (torch.Tensor): logits tensor with the [batch, sequence, hidden] dimensions
tokenizer (BioNeMoESMTokenizer): ESM2 tokenizer
Returns:
probabilities (torch.Tensor): probability tensor with [batch, sequence, tokenizer.vocab_size]
"""
aa_tokens = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
extra_indices = [i for i, token in enumerate(tokenizer.all_tokens) if token not in aa_tokens]
aa_logits = logits[..., :tokenizer.vocab_size] # filter out the 95 paddings and only keep 33 vocab positions
aa_logits[..., extra_indices] = - torch.inf # force non-amino acid token probs to zero
return torch.softmax(aa_logits, dim=-1)
注意¶
此示例中的序列维度 (1024) 表示最大序列长度,其中包括填充、EOS 和 BOS。 为了过滤相关的氨基酸信息,我们可以使用结果中的输入序列 ID 来创建掩码
input_ids = results['input_ids'] # b, s
# mask where non-amino acid tokens are True
mask = torch.isin(input_ids, torch.tensor(extra_indices))
通过 ESM-2nv 进行突变体设计¶
在本节中,我们旨在通过引入单点突变来优化输入蛋白质序列,从而使其更紧密地与自然产生的蛋白质变体对齐。 这些突变体可能表现出增强蛋白质功能的特性,例如改善的稳定性或增加的催化活性。 通过利用 ESM-2 的掩码语言建模功能,我们可以识别出比野生型残基具有更高可能性的氨基酸替换。 这种方法使我们能够有效地探索蛋白质序列空间,并有可能发现具有卓越特性的变体。
顺序掩码¶
让我们取一个起始序列并扫描各个位置,迭代地将 <mask>
token 放在每个位置现有氨基酸的位置。 然后,我们将预测每个掩码位置的概率。 如果您只想分析序列的预定义部分(例如,特定的 α 螺旋)内的替换,则可以在下面相应地设置 start_pos
和 end_pos
。
seq = 'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL' # length: 41
start_pos = 0
end_pos = len(seq)
positions = np.arange(start_pos, end_pos)
sequentially_masked = list()
for index in positions:
masked = seq[:index] + "<mask>" + seq[index+1:]
sequentially_masked.append(masked)
让我们将掩码序列保存到 CSV 文件中,并查看 sequentially_masked_sequences
的前几个元素
# Create a DataFrame
df = pd.DataFrame(sequentially_masked, columns=["sequences"])
# Save the DataFrame to a CSV file
masked_data_path = os.path.join(work_dir, "sequentially_masked_sequences.csv")
df.to_csv(masked_data_path, index=False)
df.head(n=5)
sequences | |
---|---|
0 | <mask>SLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL |
1 | M<mask>LKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL |
2 | MS<mask>KRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL |
3 | MSL<mask>RKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL |
4 | MSLK<mask>KNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL |
概率提取¶
我们现在提取 logits 并将它们转换为 sequentially_masked
的每个元素的概率矩阵。 这可以通过使用 --include-logits
调用上面的推理函数并使用 softmax 将 logits 转换为概率来轻松完成。 然后,我们可以选择对应于掩码位置的概率向量,并将它们组合成最终的概率矩阵。
%%capture --no-display --no-stderr cell_output
! infer_esm2 --checkpoint-path {checkpoint_path} \
--data-path {masked_data_path} \
--results-path {work_dir} \
--num-gpus 1 \
--precision "bf16-mixed" \
--include-logits \
--include-input-ids
results = torch.load(f"{work_dir}/predictions__rank_0.pt")
# cast to FP32 since BFloat16 is an unsupported ScalarType in numpy
logits = results['token_logits'].transpose(0, 1).to(dtype=torch.float32) # s, b, h -> b, s, h
probs = logits_to_probs(logits)
print(probs.shape)
torch.Size([41, 1024, 33])
我们只对与氨基酸 tokens 相关的概率感兴趣。因此,我们需要忽略 padding 和 eos/bos tokens。由于所有序列的长度都相同,我们可以使用它来过滤它们
probas_final = probs[:, 1:positions.size+1, :]
probas_final.shape
torch.Size([41, 41, 33])
选择并组合与每个掩码对应的概率
probas_final = probas_final[np.arange(probas_final.shape[0]), positions, :]
print(probas_final.shape)
torch.Size([41, 33])
氨基酸热图¶
让我们可视化结果。我们可以绘制每个感兴趣位置上每个 token 的预测概率。
# Create heatmap
dat = probas_final[:, aa_indices]
plt.figure(figsize=(11, 5))
im = plt.imshow(dat.T, cmap='viridis', aspect='auto')
# Add color scale
cbar = plt.colorbar(im)
cbar.set_label('Probability', rotation=270, labelpad=15)
# Set y-axis labels (amino acid tokens) and x-axis labels (position in sequence)
plt.yticks(ticks=np.arange(len(aa_tokens)), labels=aa_tokens)
plt.xticks(ticks=np.arange(dat.shape[0]), labels=list(seq))
plt.gca().xaxis.set_ticks_position('bottom')
# Add axes titles and main title
plt.xlabel('Position in Sequence')
plt.ylabel('Token Labels')
plt.title('Positional Token Probabilities')
# Adjust layout to prevent clipping of labels
plt.tight_layout()
plt.show()
突变体发现¶
我们现在可以将 logits/概率转换回序列空间,方法是将每个位置的最高概率映射到相应的氨基酸。
# Predicted seq (Argmax --> Collect token IDs of predicted seq --> Convert to amino acids)
pred_idx_list = np.argmax(probas_final, axis=-1).tolist()
pred_seq = "".join([tokenizer.id_to_token(id) for id in pred_idx_list])
# Original seq
true_idx_list = [tokenizer.token_to_id(seq[i]) for i in positions]
true_seq = "".join([tokenizer.id_to_token(id) for id in true_idx_list])
让我们比较序列,并直观地检查建议使用突变体而不是野生型的位置。请注意,预测序列显示在顶部,原始序列显示在底部。
# Compare prediction (reconstruction) to true (input sequence)
display(pred_seq + " (Predicted Sequence)")
display(
"".join(
["." if a == b else "|" for a, b in zip(pred_seq, true_seq)]
)
)
display(true_seq + " (Input Sequence)")
'MSEKKKVVALILAAGKGSRLGAGRPKQFLKIGGKTILERTL (Predicted Sequence)'
'..|.|.||...|...|.|.|..||...|||..|..|..||.'
'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL (Input Sequence)'
在不匹配项中,我们可以
- 收集所有建议使用突变体而不是野生型氨基酸的位置。
- 在这些位置,找到具有最高概率的突变体。
# Collect indices where a mutant is suggested over the wild-type
matches = [c1 == c2 for c1, c2 in zip(pred_seq, true_seq)]
mismatch_index = [i for i, value in enumerate(matches) if not value]
# Filter probability matrix to mismatches-only
probas_mismatch = probas_final[mismatch_index, :]
# Find index of mutant with highest likelihood
index_flat = np.argmax(probas_mismatch)
index_2d = np.unravel_index(index_flat, probas_mismatch.shape)
index_of_interest = mismatch_index[index_2d[0]]
position_of_interest = positions[index_of_interest]
print("Position:", position_of_interest)
print("Mutation:", true_seq[position_of_interest] + str(position_of_interest) + pred_seq[position_of_interest])
Position: 32 Mutation: S32G
让我们检查与此位置的突变相关的概率。
# Sort tokens by probability
token_ids_sort = sorted(enumerate(probas_final[index_of_interest]), key=lambda x: x[1], reverse=True)
tokens_sort = [(tokenizer.all_tokens[i], i, p.item()) for i, p in token_ids_sort]
tokens_sort_df = pd.DataFrame(tokens_sort, columns=['Token', 'Token ID', 'Probability'])
tokens_sort_df.head()
Token | Token ID | 概率 | |
---|---|---|---|
0 | G | 6 | 0.827384 |
1 | D | 13 | 0.055431 |
2 | E | 9 | 0.032586 |
3 | N | 17 | 0.030137 |
4 | S | 8 | 0.018279 |
很明显,对于这个位置,氨基酸甘氨酸 (G) 比野生型丝氨酸 (S) 具有更高的可能性。通过这种方式,我们可以使用 ESM-2nv 来设计用于下游测试的新型突变体候选物。
我们可以通过多种方式从 ESM-2nv 输出中设计候选物。我们可以继续寻找前 n 个单点突变体,找到前 n 个双点或多点突变体,在输入序列生成的概率空间中随机抽样,仅在某些感兴趣的位置(例如,已知的活性位点)内抽样等等。通过此过程,可以开发一组突变体用于进一步的计算机模拟或湿实验室测试。