
Utilities

pickles_to_tars(dir_input, input_prefix_subset, input_suffix, dir_output, output_prefix, func_output_data=lambda prefix, suffix_to_data: {'__key__': prefix, **suffix_to_data}, min_num_shards=None)

Convert a subset of pickle files from a directory to Webdataset tar files.

Input path and name pattern for sample 0:
f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}"
f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}"
Input path and name pattern for sample 1:
f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}"
f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}"
...
Output path and name pattern: f"{dir_output}/{output_prefix}-%06d.tar".

The webdataset tar archive is specified by the dictionary:
{
    "__key__": sample_filename_prefix,
    sample_filename_suffix_1: data_1,
    sample_filename_suffix_2: data_2,
    ...
}
so that parsing the tar archive is equivalent to reading {sample_filename_prefix}.{sample_filename_suffix_1} etc.

Here, each sample's data takes its name prefix from one element of input_prefix_subset and its name suffixes from the list input_suffix. Per the webdataset file format specification, sample_filename_prefix cannot contain dots '.', so this function removes them for the user by calling .replace(".", "-") on the elements of input_prefix_subset.
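
As a concrete illustration, the sketch below uses hypothetical directory and file names (not part of the library) to show how the naming pattern maps input pickles to entries in the resulting tar shards:

# Hypothetical layout: dir_input contains the pickle files
# mol-0001.features, mol-0001.labels, mol-0002.features, mol-0002.labels.
pickles_to_tars(
    dir_input="/data/pickles",
    input_prefix_subset=["mol-0001", "mol-0002"],
    input_suffix=["features", "labels"],
    dir_output="/data/shards",
    output_prefix="mols",
)
# Produces /data/shards/mols-000000.tar (more shards only if the sample count
# exceeds the per-shard maximum), whose members are
# mol-0001.features, mol-0001.labels, mol-0002.features, mol-0002.labels.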

Parameters

dir_input (str, required)
    Input directory

input_prefix_subset (List[str], required)
    Input subset of the pickle files' name prefixes

input_suffix (Union[str, Iterable[str]], required)
    Input pickle file name suffixes, one per type of data object, shared by all samples

dir_output (str, required)
    Output directory

output_prefix (str, required)
    Output tar file name prefix

func_output_data (Callable[[str, Dict[str, Any]], Dict[str, Any]], default: lambda prefix, suffix_to_data: {'__key__': prefix, **suffix_to_data})
    Function that maps the name prefix, name suffixes, and data objects to a webdataset tar archive dictionary. Refer to the webdataset GitHub repository for the archive file format specification.

min_num_shards (Optional[int], default: None)
    Create at least this many tar files. WebDataset has bugs when reading a small number of tar files in a multi-node Lightning + DDP setting, so this option can be used to guarantee the tar file count.
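
A caller can also supply a custom func_output_data to control how samples are laid out inside the archive. The sketch below is a hypothetical illustration (the helper name and paths are not part of the library); it stores each data object under an extra ".pkl" suffix:

# Hypothetical func_output_data that appends ".pkl" to every stored suffix.
def rename_suffixes(prefix, suffix_to_data):
    return {
        "__key__": prefix,
        **{f"{suffix}.pkl": data for suffix, data in suffix_to_data.items()},
    }

pickles_to_tars(
    dir_input="/data/pickles",        # hypothetical paths
    input_prefix_subset=["mol-0001"],
    input_suffix="features",
    dir_output="/data/shards",
    output_prefix="mols",
    func_output_data=rename_suffixes,
)
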
Source code in bionemo/webdatamodule/utils.py
def pickles_to_tars(
    dir_input: str,
    input_prefix_subset: List[str],
    input_suffix: Union[str, Iterable[str]],
    dir_output: str,
    output_prefix: str,
    func_output_data: Callable[[str, Dict[str, Any]], Dict[str, Any]] = lambda prefix, suffix_to_data: {
        "__key__": prefix,
        **suffix_to_data,
    },
    min_num_shards: Optional[int] = None,
) -> None:
    """Convert a subset of pickle files from a directory to Webdataset tar files.

    Input path and name pattern for sample 0:
    f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}"
    f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}"
    Input path and name pattern for sample 1:
    f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}"
    f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}"
    ...
    Output path and name pattern:
    f"{dir_output}/{output_prefix}-%06d.tar".

    The webdataset tar archive is specified by the dictionary:
    {
        "__key__" : sample_filename_preifx,
        sample_filename_suffix_1 : data_1,
        sample_filename_suffix_2 : data_2,
        ...
    }
    so that parsing the tar archive is equivalent to reading
    {sample_filename_prefix}.{sample_filename_suffix_1} etc.

    Here, each sample's data gets its name prefix from one element of
    `input_prefix_subset` and its name suffixes from the list `input_suffix`.
    Per the webdataset file format specification, the `sample_filename_prefix`
    can't contain dots '.', so this function removes them for the user by
    calling .replace(".", "-") on the elements of `input_prefix_subset`.

    Args:
        dir_input: Input directory
        input_prefix_subset: Input subset of the pickle files' name prefixes
        input_suffix: Input pickle file name
            suffixes, each for one type of data object, for all the samples
        dir_output: Output directory
        output_prefix: Output tar file name prefix
        func_output_data: function that maps the name prefix, name suffix and
            data object to a webdataset tar archive dictionary. Refer to the webdataset
            github repo for the archive file format specification.
        min_num_shards: Create at least this number of tar files.
            WebDataset has bugs when reading a small number of tar files in a
            multi-node Lightning + DDP setting, so this option can be used to
            guarantee the tar file count.
    """
    if not isinstance(input_suffix, get_args(Union[str, Iterable])):
        raise TypeError("input_suffix can only be str or Iterable[str]")
    os.makedirs(dir_output, exist_ok=True)
    wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar")
    n_samples_per_shard_max = 100000
    if min_num_shards is not None:
        if min_num_shards <= 0:
            raise ValueError(f"Invalid min_num_shards = {min_num_shards} <= 0")
        n_samples_per_shard_max = len(input_prefix_subset) // min_num_shards
    with wds.ShardWriter(
        wd_subset_pattern,
        encoder=False,
        maxcount=n_samples_per_shard_max,
        compress=False,
        mode=0o777,
    ) as sink:
        for name in input_prefix_subset:
            try:
                if isinstance(input_suffix, str):
                    suffix_to_data = {
                        input_suffix: pickle.dumps(
                            pickle.loads((Path(dir_input) / f"{name}.{input_suffix}").read_bytes())
                        )
                    }
                else:
                    suffix_to_data = {
                        suffix: pickle.dumps(pickle.loads((Path(dir_input) / f"{name}.{suffix}").read_bytes()))
                        for suffix in input_suffix
                    }
                # the prefix name shouldn't contain any "." per webdataset's
                # specification
                sample = func_output_data(name.replace(".", "-"), suffix_to_data)
                sink.write(sample)
            except ModuleNotFoundError as e:
                raise RuntimeError(
                    "Can't process pickle file due to missing dependencies"
                ) from e
            except Exception as e:
                raise RuntimeError(f"Failed to write {name} into tar files.") from e