Framework-specific API
- PyTorch
Linear
GroupedLinear
LayerNorm
RMSNorm
LayerNormLinear
LayerNormMLP
DotProductAttention
MultiheadAttention
TransformerLayer
InferenceParams
CudaRNGStatesTracker
fp8_autocast()
fp8_model_init()
checkpoint()
make_graphed_callables()
get_cpu_offload_context()
moe_permute()
moe_unpermute()
initialize_ub()
destroy_ub()
- JAX