重要提示

您正在查看 NeMo 2.0 文档。此版本引入了 API 的重大更改和一个新的库,NeMo Run。我们目前正在将 NeMo 1.0 中的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅NeMo 24.07 文档

注意力优化#

Flash Attention#

概述#

Flash attention 是一种旨在提高 Transformer 模型(如 GPT 和 BERT)中注意力机制效率的算法。注意力机制在序列长度上具有二次时间和内存复杂度,对于较长的序列可能会带来显著的运行时和内存挑战。

与标准的非 Flash 算法相比,Flash attention 应用了两种技术来降低内存需求并提高计算效率。

平铺技术根据共享内存大小分解输入,并一次计算一个平铺的 softmax。它不是一次处理整个查询、键和值张量,而是多次处理这些张量,然后在后续步骤中组合结果。

重计算技术存储 softmax 归一化因子(与序列长度呈线性关系),而不是 softmax 结果(与序列长度呈二次关系),并使用这些归一化因子重新计算注意力分数。这节省了写入全局内存的数据量,并减少了内存需求以及全局内存和共享内存之间的 I/O 流量。

Flash attention 将内存占用和计算复杂度从二次方降低到线性,大大扩展了大型语言模型中允许的序列长度范围。

Flash attention 算法最初在此处提出。它的两个实现是 Tri Dao等人flash-attention 和 NVIDIA cuDNN 的 融合 Flash Attention

开启和关闭 Flash Attention#

在 NeMo 框架中,Flash Attention 通过 Transformer Engine 支持,包括上述两种实现。Transformer Engine 根据输入信息(如序列长度、头数和头维度)选择合适的实现。当两种实现都适用时,Transformer Engine 在 Hopper+ 架构上优先选择 cuDNN Flash Attention,在 Ampere 架构上优先选择 Tri Dao Flash Attention。

要禁用 Tri Dao Flash Attention,请设置环境变量 NVTE_FLASH_ATTN=0。要禁用 cuDNN Flash Attention,请设置 NVTE_FUSED_ATTN=0

有关 Transformer Engine 中支持的点积注意力后端的更多详细信息,请参阅 Transformer Engine 的注意力机制 中的源代码。

多查询注意力 (MQA) 和分组查询注意力 (GQA)#

多查询注意力 (MQA)分组查询注意力 (GQA) 是 Transformer 模型中传统多头注意力机制的修改版本。这些方法提高了注意力机制的效率和有效性。

概述#

多查询注意力 (MQA)

MQA 将所有注意力头视为一个组,从而降低了计算复杂性并加快了训练时间。当模型可扩展性或有限的计算资源是问题时,它非常有用。

分组查询注意力 (GQA)

GQA 将头部分组到集群中,每个集群独立处理查询的子集。此方法平衡了传统多头注意力的细致关注和 MQA 的广泛方法,从而增强了细致的输入数据处理。

这些注意力变体提供:

  • 降低计算负载:两种方法都减少了计算量,这对于大型模型非常有利。

  • 提高处理速度:简化注意力可以加快训练和推理速度。

  • 灵活性和适应性:可以根据任务需求或硬件限制进行调整。

启用 MQA 和 GQA#

要在 NeMo 框架中使用 MQA 或 GQA,请调整模型配置中的 num_query_groups 参数

  1. 对于多查询注意力 (MQA):将 num_query_groups 设置为 1 以将所有注意力头视为一个组。

    from nemo.collections import llm
    from functools import partial
    
    # Load train recipe
    recipe = partial(llm.llama3_8b.pretrain_recipe)()
    
    recipe.model.config.num_query_groups = 1  # Enables Multi-query Attention
    
  2. 对于分组查询注意力 (GQA):

    • num_query_groups 设置为一个数字,该数字是注意力头总数的除数(大于 1 但小于总头数)。

    recipe.model.config.num_query_groups = <number_of_groups>  # Enables Grouped-query Attention
    
    • 对于常规注意力,请将此参数设置为 None 或使其与头数匹配。

    recipe.model.config.num_query_groups = None  # Default setting for regular multihead attention
    

也可以直接从 CLI 设置 num_query_groups

nemo llm pretrain --factory llama3_8b model.config.num_query_groups=8

调整 num_query_groups 以探索不同的注意力机制,并根据特定需求优化模型的性能。

实现 MQA 或 GQA#

NeMo 对 GQA 和 MQA 的支持是通过集成 Megatron Core 的注意力机制实现的。可以在 Megatron Core 的 Attention 类中探索底层实现细节,该类为这些高级注意力方法提供了功能骨干。要了解 MQA 和 GQA 的具体修改和实现,请参阅 Attention 类中的源代码

要查看 Megatron Core Repo 中 Attention 类的实现细节,请参阅 NVIDIA/Megatron-LM