Algorithm#

class nvmath.linalg.advanced.Algorithm(algorithm)[source]#

一个用于查询算法功能和配置算法的接口类。

请注意,此类型的对象不应由用户直接构造。

方法

__init__(algorithm)[source]#

属性

algorithm_id#

算法的 ID(整数)。

capabilities#

返回此算法的功能,以 nvmath.linalg.advanced.AlgorithmCapabilities 数据类的形式返回。

cluster_shape#

表示集群形状的元组 (参见 MatmulAlgoConfigAttribute.CLUSTER_SHAPE_ID)。

提供的值必须是算法功能中的 cluster_shape_ids 之一。

cta_swizzling#

指示 CTA 混洗的标志 (参见 MatmulAlgoConfigAttribute.CTA_SWIZZLING)。

只有当算法功能中的 cta_swizzling 为 1 时,才能设置此项。

custom_option#

指示自定义选项的值 (参见 MatmulAlgoConfigAttribute.CUSTOM_OPTION)。

提供的值必须小于算法功能中的 custom_option_max

inner_shape#

指示内部形状的值 (参见 MatmulAlgoConfigAttribute.INNER_SHAPE_ID)。

reduction_scheme#

使用的规约方案 (参见 MatmulAlgoConfigAttribute.REDUCTION_SCHEME)。

提供的值必须与算法功能中的 reduction_scheme_mask 一致。

split_k#

split-k 步骤的数量 (参见 MatmulAlgoConfigAttribute.SPLITK_NUM)。

只有当算法功能中的 splitk_support 为 1 时,才能设置此项。

stages#

表示阶段的元组 (参见 MatmulAlgoConfigAttribute.STAGES_ID)。提供的值必须是算法功能中的 stages_ids 之一。

tile#

表示瓦片的元组 (参见 MatmulAlgoConfigAttribute.TILE_ID)。提供的值必须是算法功能中的 tile_ids 之一。