跳到内容

Geneformer

BioNeMo1 中训练的当前检查点

本文档引用的性能数字和运行时引擎来自模型的 bionemo v1 变体。这些数字将在即将发布的版本中更新,以反映新的 bionemo v2 代码库。模型架构和训练信息将保持不变,因为检查点是从 bionemo v1 格式转换为 v2 格式的。以下基准测试已注释了生成它们的 bionemo 版本。由于我们有测试表明两个版本之间的模型等效性,因此准确率应在很小的 epsilon 范围内保持一致。

模型概述

描述

Geneformer 通过学习单细胞内的共表达模式,生成 sc-RNA 细胞的密集表示。Geneformer 是一个表格计数模型,在来自 Chan Zuckerberg Cell x Gene census 的 sc-RNA 上进行训练。Geneformer 计算每个细胞在前 1024 个表达基因上的完整嵌入。这些嵌入被用作各种预测任务的特征。此模型已准备好用于商业和学术用途。

参考文献

模型架构

架构类型: 来自 Transformer 的双向编码器表示 (BERT)
网络架构: Geneformer

输入

输入类型: 数字(行代表细胞,包含基因名称和单细胞表达计数)
输入格式: 数组 AnnData
输入参数: 1D

输出

输出类型: 向量(密集嵌入预测)embeddings。
输出格式: NumPy
输出参数: 1D
与输出相关的其他属性: 数值浮点向量(fp16、bf16 或 fp32);geneformer-10M-240530 输出 256 维嵌入;geneformer-106M-240530 输出 768 维嵌入

软件集成

运行时引擎

  • BioNeMo,NeMo 1.2

支持的硬件微架构兼容性

  • Ampere
  • Hopper
  • Volta

[首选/支持] 操作系统

  • Linux

模型版本

  • geneformer-10M-240530
  • 1030 万参数 geneformer 变体。
  • 25429 基于集成 ID 的基因 Token
  • 256 个隐藏维度,带有 4 个头、6 层和一个 512 维 FFN
  • relu 激活
  • 1e-12 EPS layernorm
  • bf16 混合精度训练,带有 32 位残差连接
  • 2% 隐藏 dropout,10% 注意力 dropout
  • geneformer-106M-240530
  • 1.06 亿参数 geneformer 变体。
  • 25429 基于集成 ID 的基因 Token
  • 768 个隐藏维度,带有 12 个头、12 层和一个 3072 维 FFN
  • relu 激活
  • 1e-12 EPS layernorm
  • bf16 混合精度训练,带有 32 位残差连接
  • 2% 隐藏 dropout,10% 注意力 dropout

训练与评估

训练数据集

来自 CELLxGENE Census 的单细胞表达计数用于直接下载数据,这些数据匹配 geneformer 出版物中描述的类似标准。将细胞数据限制为 organism="Homo sapiens",具有非 "na" suspension_type,is_primary_data=True 和 disease="normal",以限制为非疾病组织,这些组织也是每个细胞的主要数据来源,以确保细胞仅在下载中包含一次。我们跟踪了元数据,包括 "assay"、"sex"、"development_stage"、"tissue_general"、"dataset_id" 和 "self_reported_ethnicity"。元数据 "assay"、"tissue_general" 和 "dataset_id" 用于构建数据集拆分为训练集、验证集和测试集。

训练集代表下载细胞的 99%。我们按 dataset_id 将数据划分为训练集 (99%) 和保留集 (1%),以确保保留数据集是独立收集的单细胞实验,这有助于评估对新未来数据集的泛化能力。

在此训练拆分中,我们确保所有 "assay" 和 "tissue_general" 标签都存在于训练集中,以便我们的模型能够最大限度地了解不同的组织和检测偏差。

1% 的保留评估集进一步拆分为验证集和测试集。此最终拆分主要是通过细胞随机完成的;但是,我们将一个完整的数据集放在测试拆分中,以便我们可以在训练后评估完全未见数据集上的性能,包括在训练期间监控验证损失时。

链接:CZ CELLxGENE Discover - Cellular Visualization Tool (cziscience.com) 下载的数据集
** 按数据集的数据收集方法

  • [人类]

** 按数据集的标记方法

  • 混合:自动化,人工

属性(数量、数据集描述、传感器): 从 CZI CELLxGENE census 中选择了 2364 万个非疾病和人类来源的单细胞,其特征如下

  • 检测偏差
  • 绝大多数数据集是 10x genomics 检测之一,大约 2600 万个细胞中的 2000 万个是基因组检测,400 万个是 sci-RNA-seq,而剩余的检测(microwell-seq、drop-seq、bd rhapsody、smart-seq、seq-well 和 MARS-seq)仅占完整数据集的很小一部分。
  • 性别
  • 1250 万个是男性来源的细胞;1000 万个是女性来源的细胞。其余细胞未注释。
  • 自我报告的种族
  • 大约 1200 万个细胞未注释;900 万个被注释为“欧洲人”。 50 万个被注释为“汉族”,其次是“非裔美国人”。
  • 年龄偏差
  • 数据集严重偏向于不到一岁的捐赠者。下一个最高的群体将是包括 21-30 岁年龄段的群体。
  • 组织类型偏差
  • 900 万个细胞来自“大脑”。400 万个来自血液,其次是“肺”、“乳腺”、“心脏”和“眼睛”,每个大约 100 万个细胞。

数据集来源于有限数量的公共来源,这些来源的方法和协议可能无法充分代表多样化的来源以捕获基因表达的完整范围。

评估数据集

Adamson 等人 2016 PERTURB-seq 数据集,通过 Harvard dataverse 访问。链接: adamson.zip - Harvard Dataverse
** 按数据集的数据收集方法

  • 人类

** 按数据集的标记方法

  • 自动化 - 分子条形码

属性(数量、数据集描述、传感器): 大约有 2 万个单细胞,其中一半代表未扰动的对照样本,另一半包含一个额外的数据表,其中包含每个细胞的 CRISPR 敲除靶标。

链接: CZ CELLxGENE Discover - Cellular Visualization Tool (cziscience.com)
** 按数据集的数据收集方法

  • 人类

** 按数据集的标记方法

  • 混合:自动化,人工

属性(数量、数据集描述、传感器)

  • 从 CZI cell x gene census 中选择了 24 万个单细胞,使得它们与先前描述的训练数据中的任何细胞都不共享 dataset_id

推理

引擎: BioNeMo,NeMo
测试硬件

  • Ampere
  • Hopper
  • Volta

*此处可能包含其他描述内容

伦理考量

NVIDIA 认为值得信赖的 AI 是一项共同责任,我们已建立政策和实践,以支持各种 AI 应用的开发。当按照我们的服务条款下载或使用时,开发人员应与其内部团队合作,以确保此模型满足相关行业和用例的要求,并解决不可预见的产品误用问题。有关此模型的伦理考量的更多详细信息,请参阅 Model Card++ 可解释性、偏差、安全性和隐私子卡 [在此处插入 Model Card++ 的链接]。请在此处报告安全漏洞或 NVIDIA AI 关注问题 here

训练诊断

geneformer-10M-240530

此检查点通过 CELLxGENE 拆分训练了大约 11 个 epoch。训练在 8 台服务器上执行,每台服务器有 8 个 A100 GPU,总共 115430 步,每个 GPU 的微批量大小为 32,全局批量大小为 2048。训练总共花费了 1 天 20 小时 19 分钟的挂钟时间。从下图可以看出,训练曲线和验证曲线在整个训练过程中都相当平稳地下降。事实上,在数据集的 11 个 epoch 结束时,验证(蓝色)和训练(橙色)损失都在持续下降。该模型可能会在不过拟合的情况下训练更多 epoch。验证和训练损失在整个训练过程中都平稳下降

来自 BioNeMo1 的训练曲线

请注意,这些曲线是在 BioNeMo1 上生成的。但是,我们在 BioNeMo2 的初始测试中看到了相同的总体训练曲线。在下图中,蓝线是 10M 模型的先前训练运行,红线是 BioNeMo2 上等效的训练运行。当我们发布新的检查点时,它们将在 BioNeMo2 上进行训练。

Training curve equivalence

geneformer-106M-240530

此检查点通过 CELLxGENE 拆分训练了大约 11 个 epoch。训练在 16 台服务器上执行,每台服务器有 8 个 A100 GPU,总共 115430 步,每个 GPU 的微批量大小为 16,全局批量大小为 2048。训练总共花费了 3 天 18 小时 55 分钟的挂钟时间。从下图可以看出,训练曲线和验证曲线在整个训练过程中都相当平稳地下降。事实上,在数据集的 11 个 epoch 结束时,验证(蓝色)和训练(橙色)损失都在持续下降。该模型可能会在不过拟合的情况下训练更多 epoch。验证和训练损失在整个训练过程中都平稳下降

此外,与 10M 参数模型(蓝色)相比,106M 参数模型(红色)中的验证损失下降得更快,并且在整个训练过程中继续以相同的改进率下降。测试更大的模型以查看我们是否继续观察到更大模型中性能的提高将很有趣。106M 参数模型优于 10M 参数模型

!! note "来自 BioNeMo1 的训练曲线"

As stated in the previous section, the figures are from our BioNeMo1 code base where these checkpoints were originally
trained. As we release new checkpoints they will be trained on BioNeMo2.

基准测试

准确率基准

掩码语言模型 (MLM) 损失

以下描述了 bert MLM Token 损失。与原始 BERT 论文和 geneformer 论文一样,所有 Token 的 15% 都包含在损失中。在包含的 Token 中,80% 是 "[MASK]" Token,2% 是随机基因 Token,18% 是正确的输出 Token。请注意,这是与原始出版物无意的偏差,但到目前为止,它似乎运行良好。将来,我们将测试论文中提出的预期 80%/10%/10% 混合。下表中的 Token 损失是损失掩码中包含的 15% 的 Token 的平均交叉熵损失,该损失在细胞之间平均。作为基线,geneformer 是从 ctheodoris/Geneformer 页面在 hugging face 上于 2024/11/04 下载 并应用于此数据集上的相同掩码/取消掩码问题,但由于更新的分词器和用于训练的中位数词典,以及从每个细胞训练 2048 个 Token 更新到 4096 个 Token,因此具有特定于模型的细胞表示。使用了我们先前描述的训练拆分中的保留 test 数据集,应该注意的是,其中一些细胞可能参与了基线 geneformer 的训练。

模型描述 Token 损失(越低越好)
基线 geneformer 2.26*
geneformer-10M-240530 2.64
geneformer-106M-240530 2.34

基线 Geneformer 最近在 huggingface 上更新,使得损失比较具有挑战性。

Geneformer 最近在 hugging face 上更新到了新版本。在未来的版本中,我们将提供检查点转换脚本,以便可以直接运行公共模型。以下是一些主要差异

  • 在更大的 9500 万个细胞数据集上训练。我们当前的检查点使用 2300 万个细胞进行训练。
  • 新的 12 层基线 geneformer 变体在参数计数方面介于我们的 10M 和 106M 参数模型之间,大约有 3800 万个参数。
  • 该模型使用 4096 上下文而不是 2048 上下文进行训练。当强制模型使用 2048 上下文进行预测时,MLM 损失降至 2.76,这可能是不公平的,因为这可能“超出训练领域”。唯一需要注意的是,直接比较这些损失数字确实很困难。
  • 该模型在一组 20,275 个基因上进行训练,而不是较旧的一组 25,426 个基因。由于可供选择的 Token 较少,预计这也将提高损失。

下游任务准确率

在这里,我们对四个模型进行基准测试,其中两个是基线。这些模型的任务是细胞类型分类,使用来自 Elmentaite 等人的 Chron's disease 小肠数据集。(2020),Developmental Cell。该数据集包含来自 4-13 岁健康儿童和患有 Chron 病儿童的大约 22,500 个单细胞。该数据集包含 31 种独特的细胞类型,我们假设这些细胞类型被准确注释。由于所有患病样本都被移除,因此该数据集从我们的预训练数据集中排除。

  • 基线 1) scRNA 工作流程:此模型使用 PCA 和 10 个组件以及归一化和对数转换表达计数的随机森林来生成结果。
  • 基线 2) geneformer-qa,一个使用大约随机权重训练约 100 步的模型。我们预计该模型的性能与直接处理计数没有区别。
  • geneformer-10M-240530 和 geneformer-106M-240530 如上所述。

有关更多详细信息,请参阅名为 Geneformer-celltype-classification-example.ipynb 的示例笔记本

F1-score for both released models, a random baseline, and a PCA based transformation of the raw expression. Average accuracy across cell types for both released models, a random baseline, and a PCA based transformation of the raw expression.

性能基准

Geneformer 的 1.06 亿参数变体在训练期间每个 GPU 实现超过 50 TFLOPS。无论使用 1 个还是 8 个 A100 进行训练,结果都是一致的。

TFLOPs per GPU (A100) shows improved utilization by 106M variant

TFLOPS 来自 BioNeMo1,BioNeMo2 中加速的早期证据

我们观察到 Geneformer 模型在 BioNeMo2 中的性能与 BioNeMo1 相当或更好。一个例子是集群运行,其中我们看到每个步骤的时间为每次迭代 0.26 秒,批处理大小为 64,通过 Geneformer。配置正确的旧 BioNeMo1 运行在批处理大小为 16 的情况下,每个步骤的时间为 0.09。当您考虑每秒样本数时,这将意味着 BioNeMo2 的速度显着提高,但这只是传闻,更彻底的比较即将到来。