
Datamodule utils

float_or_int_or_none(value)

Converts a given value into a float, an int, or None.

Parameters

Name | Type | Description | Default
value | Union[str, float, int, None] | A value that can be a string, float, int, or None. | required

Returns

Type | Description
Union[float, int, None] | A float, int, or None, depending on the input value.

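If the input value is None or the string "None", the function returns None. If it is already an int or a float, it is returned unchanged. Any other string is passed to int() when str.isdigit() is true (for example "42") and to float() otherwise (for example "0.25", "-3", or "1e3"); a string that float() cannot parse raises ValueError.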

Source code in bionemo/llm/utils/datamodule_utils.py
def float_or_int_or_none(value: Union[str, float, int, None]) -> Union[float, int, None]:
    """Converts a given value into a float, int, or None.

    Args:
        value (Union[str, float, int, None]): A value that can be either a string, float, int, or None.

    Returns:
        Union[float, int, None]: A float, int, or None based on the input value.

    If the input value is None or "None", it returns None.
    If the input value is an int or float, it returns the same value.
    If the input value is a string, it tries to convert it into an int if possible, otherwise into a float.
    """
    if value is None or value == "None":
        return
    if isinstance(value, (int, float)):
        return value
    if value.isdigit():
        return int(value)
    return float(value)
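To illustrate the string handling described above, the sketch below copies the function's logic from the source so it runs without bionemo installed (in a real project it would presumably be imported from `bionemo.llm.utils.datamodule_utils`). The inputs are illustrative values of the kind a command-line parser produces.

```python
from typing import Union


def float_or_int_or_none(value: Union[str, float, int, None]) -> Union[float, int, None]:
    # Logic copied from the source above so this example runs standalone.
    if value is None or value == "None":
        return None
    if isinstance(value, (int, float)):
        return value
    if value.isdigit():
        return int(value)
    return float(value)


# Strings made only of digits become ints; every other string goes through float().
print(float_or_int_or_none("42"))    # 42
print(float_or_int_or_none("0.25"))  # 0.25
print(float_or_int_or_none("-3"))    # -3.0  ("-3".isdigit() is False)
print(float_or_int_or_none("1e3"))   # 1000.0
print(float_or_int_or_none("None"))  # None
print(float_or_int_or_none(8))       # 8 (non-strings pass through unchanged)
```

A consequence of the isdigit() check is that negative integers written as strings come back as floats, and a non-numeric string such as "abc" raises ValueError from float().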

infer_global_batch_size(micro_batch_size, num_nodes, devices, accumulate_grad_batches=1, tensor_model_parallel_size=1, pipeline_model_parallel_size=1)

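Infers the global batch size from the micro batch size, the number of nodes, the number of devices, the number of gradient accumulation batches, and the model-parallel sizes.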
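Reading the source below, the computation reduces to the following formula, written here with shorthand symbols chosen for this note: MBS is micro_batch_size, N_nodes is num_nodes, N_devices is devices, TP and PP are tensor_model_parallel_size and pipeline_model_parallel_size, and GA is accumulate_grad_batches. The fraction is the data-parallel size; if it is not an integer the function raises ValueError. The right-hand side is a worked example with illustrative values.

$$
\mathrm{GBS} = \mathrm{MBS} \times \frac{N_{\text{nodes}} \times N_{\text{devices}}}{\mathrm{TP} \times \mathrm{PP}} \times \mathrm{GA}
\qquad\text{e.g.}\qquad
4 \times \frac{2 \times 8}{2 \times 1} \times 2 = 4 \times 8 \times 2 = 64
$$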

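Parameters

Name | Type | Description | Default
micro_batch_size | int | The micro batch size. | required
num_nodes | int | The number of nodes. | required
devices | int | The number of devices per node. | required
accumulate_grad_batches | int | The number of batches over which gradients are accumulated. | 1
tensor_model_parallel_size | int | The tensor model-parallel size. | 1
pipeline_model_parallel_size | int | The pipeline model-parallel size. | 1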

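Returns

Type | Description
int | The global batch size.

Raises

Type | Description
ValueError | If any argument is not an int or is not greater than 0, or if num_nodes * devices is not divisible by tensor_model_parallel_size * pipeline_model_parallel_size.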

Source code in bionemo/llm/utils/datamodule_utils.py
def infer_global_batch_size(
    micro_batch_size: int,
    num_nodes: int,
    devices: int,
    accumulate_grad_batches: int = 1,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
) -> int:
    """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes.

    Args:
        micro_batch_size (int): The micro batch size.
        num_nodes (int): The number of nodes.
        devices (int): The number of devices.
        accumulate_grad_batches (int): The accumulation of gradient batches. Defaults to 1.
        tensor_model_parallel_size (int): The tensor model parallel size. Defaults to 1.
        pipeline_model_parallel_size (int): The pipeline model parallel size. Defaults to 1.

    Returns:
        int: The global batch size.
    """
    if not all(
        isinstance(arg, int)
        for arg in [
            micro_batch_size,
            num_nodes,
            devices,
            accumulate_grad_batches,
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
        ]
    ):
        raise ValueError(
            f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, "
            f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, and {type(pipeline_model_parallel_size)}"
        )
    if micro_batch_size <= 0:
        raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}")
    if num_nodes <= 0:
        raise ValueError(f"num_nodes must be greater than 0, got {num_nodes}")
    if devices <= 0:
        raise ValueError(f"devices must be greater than 0, got {devices}")
    if accumulate_grad_batches <= 0:
        raise ValueError(f"accumulate_grad_batches must be greater than 0, got {accumulate_grad_batches}")
    if tensor_model_parallel_size <= 0:
        raise ValueError(f"tensor_model_parallel_size must be greater than 0, got {tensor_model_parallel_size}")
    if pipeline_model_parallel_size <= 0:
        raise ValueError(f"pipeline_model_parallel_size must be greater than 0, got {pipeline_model_parallel_size}")

    world_size = num_nodes * devices
    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise ValueError(
            f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size, "
            f"got {world_size} and {tensor_model_parallel_size} * {pipeline_model_parallel_size}"
        )

    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
    data_parallel_size = world_size // model_parallel_size
    global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches
    return global_batch_size
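The following is a condensed restatement of the same arithmetic, not the library function itself: `global_batch_size_sketch` is a hypothetical name, and the type and positivity checks are omitted so the data-parallel step stands out. It reproduces a worked example and the divisibility check.

```python
def global_batch_size_sketch(
    micro_batch_size: int,
    num_nodes: int,
    devices: int,
    accumulate_grad_batches: int = 1,
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
) -> int:
    """Hypothetical condensed version of infer_global_batch_size (input validation omitted)."""
    world_size = num_nodes * devices
    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
    if world_size % model_parallel_size != 0:
        raise ValueError(f"world_size {world_size} is not divisible by model_parallel_size {model_parallel_size}")
    # Each data-parallel replica spans model_parallel_size ranks.
    data_parallel_size = world_size // model_parallel_size
    return micro_batch_size * data_parallel_size * accumulate_grad_batches


# 2 nodes x 8 devices = 16 ranks; TP=2 leaves 8 data-parallel replicas.
print(global_batch_size_sketch(4, num_nodes=2, devices=8, accumulate_grad_batches=2, tensor_model_parallel_size=2))  # 64

# 1 node x 6 devices cannot be split into groups of TP=4 ranks.
try:
    global_batch_size_sketch(4, num_nodes=1, devices=6, tensor_model_parallel_size=4)
except ValueError as err:
    print("rejected:", err)
```

Model parallelism shrinks the global batch because ranks inside one tensor- or pipeline-parallel group work on the same micro batch rather than on different data.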

infer_num_samples(limit_batches, num_samples_in_dataset, global_batch_size, stage)

Infers the number of samples from the limit_batches parameter, the dataset length, and the global batch size.

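Parameters

Name | Type | Description | Default
limit_batches | Union[float, int, str, None] | The limit on the number of batches: a float in (0, 1] (a fraction of the dataset), an int >= 1 (a number of global batches), or None, which defaults to 1.0. The type hint admits a string, but a string must first be converted to a number (for example with float_or_int_or_none); passed as-is it raises TypeError at the numeric comparison. | required
num_samples_in_dataset | int | The number of samples in the dataset. | required
global_batch_size | int | The global batch size. | required
stage | str | The training stage; used only in error messages. | required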

Returns

Type | Description
int | The number of samples after applying the limit.

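Raises

Type | Description
ValueError | If a fractional limit_batches yields fewer samples than the global batch size, or if limit_batches is neither a float in (0, 1] nor an int >= 1.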

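If limit_batches is a float between 0 and 1, the number of samples is that fraction of the samples in the dataset. If it is an int greater than or equal to 1, the number of samples is limit_batches multiplied by the global batch size. If it is None, it defaults to 1.0, meaning all dataset samples are used. Note that the type matters: the float 1.0 selects the whole dataset, whereas the int 1 selects a single global batch.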

Source code in bionemo/llm/utils/datamodule_utils.py
def infer_num_samples(
    limit_batches: Union[float, int, str, None], num_samples_in_dataset: int, global_batch_size: int, stage: str
):
    """Infers the number of samples based on the limit_batches parameter, the length of the dataset, and the global batch size.

    Args:
        limit_batches (Union[float, int, str, None]): The limit on the number of batches. Can be a float
            between 0 and 1, an integer, a string, or None. If None, defaults to 1.0.
        num_samples_in_dataset (int): The number of samples in the dataset.
        global_batch_size (int): The global batch size.
        stage (str): The stage of the training.

    Returns:
        int: The number of samples from the limit.

    Raises:
        ValueError: If the limited number of samples is less than the global batch size, or if the
            limit_batches parameter is invalid.

    If limit_batches is a float between 0 and 1, the number of samples is inferred as a fraction of the number of samples
    in the dataset. If limit_batches is an integer greater than or equal to 1, the number of limited samples is inferred
    as the product of limit_batches and global batch size. If limit_batches is None, it defaults to 1.0, indicating that
    all dataset samples should be used.
    """
    limit_batches = 1.0 if limit_batches is None else limit_batches  # validation data does not require upsampling
    if 0 < limit_batches <= 1.0 and isinstance(limit_batches, float):
        num_limited_samples = int(num_samples_in_dataset * limit_batches)
        if num_limited_samples < global_batch_size:
            raise ValueError(
                "The limited number of %s samples %s is less than the global batch size %s"
                % (stage, num_limited_samples, global_batch_size)
            )
    elif limit_batches >= 1 and isinstance(limit_batches, int):
        num_limited_samples = int(limit_batches * global_batch_size)
    else:
        raise ValueError("Invalid choice of limit_%s_batches size: %s" % (stage, limit_batches))

    return num_limited_samples
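A standalone sketch exercising each branch. The function body is copied from the source above; the dataset size, batch size, and stage names are illustrative assumptions. It highlights two easy-to-miss behaviours: the int 1 and the float 1.0 mean different things, and a raw string is rejected by the numeric comparison rather than parsed.

```python
from typing import Union


def infer_num_samples(
    limit_batches: Union[float, int, str, None], num_samples_in_dataset: int, global_batch_size: int, stage: str
) -> int:
    # Logic copied from the source above so the example runs standalone.
    limit_batches = 1.0 if limit_batches is None else limit_batches
    if 0 < limit_batches <= 1.0 and isinstance(limit_batches, float):
        num_limited_samples = int(num_samples_in_dataset * limit_batches)
        if num_limited_samples < global_batch_size:
            raise ValueError(
                "The limited number of %s samples %s is less than the global batch size %s"
                % (stage, num_limited_samples, global_batch_size)
            )
    elif limit_batches >= 1 and isinstance(limit_batches, int):
        num_limited_samples = int(limit_batches * global_batch_size)
    else:
        raise ValueError("Invalid choice of limit_%s_batches size: %s" % (stage, limit_batches))
    return num_limited_samples


print(infer_num_samples(0.5, 10_000, 32, "train"))  # 5000: half the dataset
print(infer_num_samples(None, 10_000, 32, "val"))   # 10000: None means 1.0, the whole dataset
print(infer_num_samples(1, 10_000, 32, "val"))      # 32: the int 1 means one global batch
print(infer_num_samples(50, 10_000, 32, "train"))   # 1600: 50 batches x 32 samples

# 2.0 is a float above 1, 0.001 leaves only 10 samples (< 32), and 0 is not >= 1.
for bad in (2.0, 0.001, 0):
    try:
        infer_num_samples(bad, 10_000, 32, "train")
    except ValueError as err:
        print("ValueError:", err)

# A raw string fails at `0 < limit_batches`; convert it to a number first.
try:
    infer_num_samples("0.5", 10_000, 32, "train")
except TypeError as err:
    print("TypeError:", err)
```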

parse_kwargs_to_arglist(kwargs)

Converts a dictionary of keyword arguments into a list of command-line arguments.

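Parameters

Name | Type | Description | Default
kwargs | Dict[str, Any] | A dictionary whose keys are argument names and whose values are argument values. | required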

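Returns

Type | Description
List[str] | A flat list in which each key/value pair contributes two consecutive strings: '--argument-name' (the key with underscores replaced by hyphens) followed by str(value).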

Source code in bionemo/llm/utils/datamodule_utils.py
def parse_kwargs_to_arglist(kwargs: Dict[str, Any]) -> List[str]:
    """Converts a dictionary of keyword arguments into a list of command-line arguments.

    Args:
        kwargs (Dict[str, Any]): A dictionary where keys are argument names and values are argument values.

    Returns:
        A list of strings, where each string is a command-line argument in the format '--argument-name value'.
    """
    arglist = []
    for k, v in kwargs.items():
        arglist.extend([f"--{k.replace('_', '-')}", str(v)])
    return arglist
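The snippet below copies the function from the source and feeds its output into a hypothetical argparse parser to show the round trip: argparse maps --micro-batch-size back to the attribute micro_batch_size. The keyword names and the parser are illustrative assumptions, not taken from bionemo.

```python
import argparse
from typing import Any, Dict, List


def parse_kwargs_to_arglist(kwargs: Dict[str, Any]) -> List[str]:
    # Copied from the source above so the example runs standalone.
    arglist = []
    for k, v in kwargs.items():
        arglist.extend([f"--{k.replace('_', '-')}", str(v)])
    return arglist


args = parse_kwargs_to_arglist({"micro_batch_size": 4, "lr": 0.001, "limit_val_batches": None})
print(args)
# ['--micro-batch-size', '4', '--lr', '0.001', '--limit-val-batches', 'None']

# Hypothetical parser: argparse converts hyphens in option names back to underscores.
parser = argparse.ArgumentParser()
parser.add_argument("--micro-batch-size", type=int)
parser.add_argument("--lr", type=float)
parser.add_argument("--limit-val-batches", type=str)
ns = parser.parse_args(args)
print(ns.micro_batch_size, ns.lr, ns.limit_val_batches)  # 4 0.001 None
```

Because every value goes through str(), None arrives as the literal string "None" (the case float_or_int_or_none recognises), and booleans become "True"/"False" rather than bare switches.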