
Batch Augmentation

BatchAugmentation

A factory that creates batch augmentation objects for a specified optimal transport type.

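Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `str` | The device to use for computations (e.g., `'cpu'`, `'cuda'`). | *required* |
| `num_threads` | `int` | The number of threads to use. | *required* |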
Source code in bionemo/moco/interpolants/batch_augmentation.py
```python
# OptimalTransportType, OTSampler, KabschAugmentation and EquivariantOTSampler are
# defined elsewhere in bionemo-moco; their imports are not shown in this excerpt.
class BatchAugmentation:
    """Facilitates the creation of batch augmentation objects based on specified optimal transport types.

    Args:
        device (str): The device to use for computations (e.g., 'cpu', 'cuda').
        num_threads (int): The number of threads to utilize.
    """

    def __init__(self, device, num_threads):
        """Initializes a BatchAugmentation instance.

        Args:
            device (str): Device for computation.
            num_threads (int): Number of threads to use.
        """
        self.device = device
        self.num_threads = num_threads

    def create(self, method_type: OptimalTransportType):
        """Creates a batch augmentation object of the specified type.

        Args:
            method_type (OptimalTransportType): The type of optimal transport method.

        Returns:
            The augmentation object if the type is supported, otherwise **None**.
        """
        if method_type == OptimalTransportType.EXACT:
            augmentation = OTSampler(method="exact", device=self.device, num_threads=self.num_threads)
        elif method_type == OptimalTransportType.KABSCH:
            augmentation = KabschAugmentation()
        elif method_type == OptimalTransportType.EQUIVARIANT:
            augmentation = EquivariantOTSampler(method="exact", device=self.device, num_threads=self.num_threads)
        else:
            return None
        return augmentation
```
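To make the dispatch concrete without installing bionemo-moco, here is a minimal sketch that mirrors the class listing above but replaces the real `OptimalTransportType`, `OTSampler`, `KabschAugmentation` and `EquivariantOTSampler` with hypothetical stand-ins that only record their constructor arguments. The enum member values are assumptions. It shows which arguments each branch forwards: both OT samplers receive `method="exact"`, `device` and `num_threads`, while `KabschAugmentation` receives none of them.

```python
from enum import Enum


class OptimalTransportType(Enum):
    # Hypothetical stand-in for bionemo's OptimalTransportType; member values are assumed.
    EXACT = "exact"
    KABSCH = "kabsch"
    EQUIVARIANT = "equivariant"


class _StubSampler:
    # Hypothetical stand-in that just records the keyword arguments it was built with.
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class OTSampler(_StubSampler):
    pass


class EquivariantOTSampler(_StubSampler):
    pass


class KabschAugmentation(_StubSampler):
    pass


class BatchAugmentation:
    # Mirrors the dispatch in the source listing, written with early returns.
    def __init__(self, device, num_threads):
        self.device = device
        self.num_threads = num_threads

    def create(self, method_type):
        if method_type == OptimalTransportType.EXACT:
            return OTSampler(method="exact", device=self.device, num_threads=self.num_threads)
        if method_type == OptimalTransportType.KABSCH:
            return KabschAugmentation()
        if method_type == OptimalTransportType.EQUIVARIANT:
            return EquivariantOTSampler(method="exact", device=self.device, num_threads=self.num_threads)
        return None  # unsupported type


factory = BatchAugmentation(device="cpu", num_threads=4)
exact = factory.create(OptimalTransportType.EXACT)
kabsch = factory.create(OptimalTransportType.KABSCH)
equivariant = factory.create(OptimalTransportType.EQUIVARIANT)

print(type(exact).__name__, exact.kwargs)    # OTSampler {'method': 'exact', 'device': 'cpu', 'num_threads': 4}
print(type(kabsch).__name__, kabsch.kwargs)  # KabschAugmentation {}
```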

__init__(device, num_threads)

Initializes a BatchAugmentation instance.

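Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `str` | Device for computation. | *required* |
| `num_threads` | `int` | Number of threads to use. | *required* |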
Source code in bionemo/moco/interpolants/batch_augmentation.py
```python
def __init__(self, device, num_threads):
    """Initializes a BatchAugmentation instance.

    Args:
        device (str): Device for computation.
        num_threads (int): Number of threads to use.
    """
    self.device = device
    self.num_threads = num_threads
```
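The constructor only stores `device` and `num_threads` as plain attributes. It does no validation and builds nothing eagerly, so `create()` reads them at call time. The sketch below illustrates that consequence. It uses a hypothetical `OTSampler` stand-in (not the bionemo class) and a hypothetical `create_exact` helper that reproduces only the `EXACT` branch of `create()`.

```python
class OTSampler:
    # Hypothetical stand-in for bionemo's OTSampler; it only records its arguments.
    def __init__(self, method, device, num_threads):
        self.method = method
        self.device = device
        self.num_threads = num_threads


class BatchAugmentation:
    # __init__ mirrors the listing: it stores the settings and does nothing else.
    def __init__(self, device, num_threads):
        self.device = device
        self.num_threads = num_threads

    def create_exact(self):
        # Hypothetical helper reproducing only the EXACT branch of create().
        # self.device and self.num_threads are read here, at call time.
        return OTSampler(method="exact", device=self.device, num_threads=self.num_threads)


factory = BatchAugmentation(device="cpu", num_threads=1)
first = factory.create_exact()

# Changing the attributes affects only samplers created afterwards.
factory.device = "cuda"
factory.num_threads = 8
second = factory.create_exact()

print(first.device, first.num_threads)    # cpu 1
print(second.device, second.num_threads)  # cuda 8
```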

create(method_type)

Creates a batch augmentation object of the specified type.

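Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `method_type` | `OptimalTransportType` | The type of optimal transport method. | *required* |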

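Returns:

| Type | Description |
|---|---|
| `OTSampler`, `KabschAugmentation`, `EquivariantOTSampler`, or `None` | The augmentation object if the type is supported, otherwise `None`. |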
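Because an unsupported type produces `None` rather than an exception, a caller may want to fail early. The following is a minimal caller-side sketch. `require_augmentation` is a hypothetical helper, not part of bionemo-moco, and the string arguments stand in for real `create()` results and enum values.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def require_augmentation(augmentation: Optional[T], method_type: object) -> T:
    # Hypothetical caller-side guard: create() signals an unsupported type by returning None,
    # so raise here instead of letting None propagate to downstream code.
    if augmentation is None:
        raise ValueError(f"Unsupported optimal transport type: {method_type!r}")
    return augmentation


# Stand-in values; in real code the first argument would be factory.create(method_type).
print(require_augmentation("sampler", "EXACT"))  # sampler
try:
    require_augmentation(None, "sinkhorn")
except ValueError as err:
    print(err)  # Unsupported optimal transport type: 'sinkhorn'
```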

Source code in bionemo/moco/interpolants/batch_augmentation.py
```python
def create(self, method_type: OptimalTransportType):
    """Creates a batch augmentation object of the specified type.

    Args:
        method_type (OptimalTransportType): The type of optimal transport method.

    Returns:
        The augmentation object if the type is supported, otherwise **None**.
    """
    if method_type == OptimalTransportType.EXACT:
        augmentation = OTSampler(method="exact", device=self.device, num_threads=self.num_threads)
    elif method_type == OptimalTransportType.KABSCH:
        augmentation = KabschAugmentation()
    elif method_type == OptimalTransportType.EQUIVARIANT:
        augmentation = EquivariantOTSampler(method="exact", device=self.device, num_threads=self.num_threads)
    else:
        return None
    return augmentation
```
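The `if`/`elif` chain in `create()` can also be written as a lookup table. The sketch below is an alternative formulation, not the library's code. The enum is a hypothetical stand-in with assumed values, and `_Recorder` is a hypothetical placeholder for the three bionemo augmentation classes that records the class name and arguments each branch would use. As in the original, any value that is not a supported member, here the string `"sinkhorn"`, yields `None`.

```python
from enum import Enum
from typing import Callable, Dict, Optional


class OptimalTransportType(Enum):
    # Hypothetical stand-in; the real enum lives in bionemo-moco.
    EXACT = "exact"
    KABSCH = "kabsch"
    EQUIVARIANT = "equivariant"


class _Recorder:
    # Hypothetical placeholder for OTSampler / KabschAugmentation / EquivariantOTSampler.
    def __init__(self, name, **kwargs):
        self.name = name
        self.kwargs = kwargs


def create(method_type, device, num_threads) -> Optional[_Recorder]:
    # Table-driven variant of create(): each supported type maps to a zero-argument builder,
    # so nothing is constructed until its type is actually requested.
    builders: Dict[OptimalTransportType, Callable[[], _Recorder]] = {
        OptimalTransportType.EXACT: lambda: _Recorder(
            "OTSampler", method="exact", device=device, num_threads=num_threads
        ),
        OptimalTransportType.KABSCH: lambda: _Recorder("KabschAugmentation"),
        OptimalTransportType.EQUIVARIANT: lambda: _Recorder(
            "EquivariantOTSampler", method="exact", device=device, num_threads=num_threads
        ),
    }
    builder = builders.get(method_type)
    # Unsupported types fall through to None, matching the documented return contract.
    return builder() if builder is not None else None


ot = create(OptimalTransportType.EXACT, device="cpu", num_threads=2)
print(ot.name, ot.kwargs)                          # OTSampler {'method': 'exact', 'device': 'cpu', 'num_threads': 2}
print(create("sinkhorn", device="cpu", num_threads=2))  # None
```

The table keeps all supported types in one place. The cost is that each builder must be wrapped in a `lambda` to preserve the lazy construction that the `if`/`elif` chain gets for free.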