跳到内容

Peft

ESM2LoRA

Bases: LoRA

用于 BioNeMo2 ESM 模型的 LoRA。

源代码位于 bionemo/esm2/model/finetune/peft.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class ESM2LoRA(LoRA):
    """LoRA for the BioNeMo2 ESM Model."""

    def __call__(self, model: nn.Module) -> nn.Module:
        """This method is called when the object is called as a function.

        Args:
            model: The input model.

        Returns:
            The modified model.
        """
        fn.walk(model, self.selective_freeze)
        fn.walk(model, self.transform)
        return model

    def selective_freeze(self, m: nn.Module, name=None, prefix=None):
        """Freezes specific modules in the given model.

        Args:
            m (nn.Module): The model to selectively freeze.
            name (str): The name of the module to freeze. Valid values are "encoder" and "embedding".
            prefix (str): The prefix of the module to freeze.

        Returns:
            nn.Module: The modified model with the specified modules frozen.

        See Also:
            nemo.collections.llm.fn.mixin.FNMixin
        """
        if name in ["encoder", "embedding"]:
            FNMixin.freeze(m)
        return m

__call__(model)

当对象作为函数调用时,将调用此方法。

参数

名称 类型 描述 默认
model 模块

输入模型。

必需

返回

类型 描述
模块

修改后的模型。

源代码位于 bionemo/esm2/model/finetune/peft.py
40
41
42
43
44
45
46
47
48
49
50
51
def __call__(self, model: nn.Module) -> nn.Module:
    """This method is called when the object is called as a function.

    Args:
        model: The input model.

    Returns:
        The modified model.
    """
    fn.walk(model, self.selective_freeze)
    fn.walk(model, self.transform)
    return model

selective_freeze(m, name=None, prefix=None)

冻结给定模型中的特定模块。

参数

名称 类型 描述 默认
m 模块

要选择性冻结的模型。

必需
名称 字符串

要冻结的模块的名称。有效值为“encoder”和“embedding”。

前缀 字符串

要冻结的模块的前缀。

返回

类型 描述

nn.Module:冻结指定模块的修改后的模型。

另请参阅

nemo.collections.llm.fn.mixin.FNMixin

源代码位于 bionemo/esm2/model/finetune/peft.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def selective_freeze(self, m: nn.Module, name=None, prefix=None):
    """Freezes specific modules in the given model.

    Args:
        m (nn.Module): The model to selectively freeze.
        name (str): The name of the module to freeze. Valid values are "encoder" and "embedding".
        prefix (str): The prefix of the module to freeze.

    Returns:
        nn.Module: The modified model with the specified modules frozen.

    See Also:
        nemo.collections.llm.fn.mixin.FNMixin
    """
    if name in ["encoder", "embedding"]:
        FNMixin.freeze(m)
    return m