嵌入
ESM2Embedding
基类:LanguageModelEmbedding
ESM2 嵌入,具有用于注意力掩码和 token dropout 的自定义逻辑。
源代码位于 bionemo/esm2/model/embedding.py
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
|
dtype: torch.dtype
property
嵌入权重的 dtype。
__init__(config, vocab_size, max_sequence_length, position_embedding_type='rope', num_tokentypes=0, token_dropout=True, use_attention_mask=True, mask_token_id=torch.nan)
初始化 ESM2 嵌入模块。
源代码位于 bionemo/esm2/model/embedding.py
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
|
forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)
嵌入模块的前向传播。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
input_ids
|
Tensor
|
输入 tokens。形状:[b, s] |
必需 |
position_ids
|
Tensor
|
用于计算位置嵌入的位置 ID。形状:[b, s] |
必需 |
tokentype_ids
|
int
|
Token 类型 ID。当 args.bert_binary_head 设置为 True 时使用。默认为 None |
None
|
attention_mask
|
Tensor
|
注意力掩码。形状:[b, s] |
None
|
返回值
名称 | 类型 | 描述 |
---|---|---|
Tensor |
Tensor
|
输出嵌入 |
源代码位于 bionemo/esm2/model/embedding.py
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
|