跳到内容

bionemo-webdatamodule

要安装,请执行以下命令

pip install -e .

要运行单元测试,请执行

pytest -v .

WebDataModule

class WebDataModule(L.LightningDataModule)

用于使用 webdataset tar 文件的 LightningDataModule。

WebDataModule 是一个 LightningDataModule,用于使用 webdataset tar 文件来设置 PyTorch 数据集和数据加载器。此数据模块接受一个字典作为输入:Split -> tar 文件目录和各种 webdataset 配置设置。在其 setup() 函数中,它创建 webdataset 对象,链接输入 pipeline_wds 工作流程。在其 train/val/test_dataloader() 函数中,它创建 WebLoader 对象,链接 pipeline_prebatch_wld 工作流程。

示例:


  1. 使用 webdataset tar 文件的输入目录创建数据模块。根据调用的下游 Lightning.Trainer 方法,例如 Trainer.fit()Trainer.validate()Trainer.test()Trainer.predict(),只需在数据模块的各种输入选项中指定 train、val 和 test 分支的子集

  2. Trainer.fit() 需要 trainval 分支

  3. Trainer.validate() 需要 val 分支
  4. Trainer.test() 需要 test 分支
  5. Trainer.predict() 需要 test 分支

以下是为 Trainer.fit() 构建数据模块的示例

>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
>>>
>>> tar_file_prefix = "shards"
>>>
>>> dirs_of_tar_files = {
>>>     Split.train: "/path/to/train/split/tars",
>>>     Split.val: "/path/to/val/split/tars",
>>> }
>>>
>>> n_samples {
>>>     Split.train: 1000,
>>>     Split.val: 100,
>>> }
>>>
>>> # this is the string to retrieve the corresponding data object from the
>>> # webdataset file (see
>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
>>> # for details)
>>> suffix_keys_wds = "tensor.pyd"
>>>
>>> seed = 27193781
>>>
>>> # Specify the routines to process the samples in the WebDataset object.
>>> # The routine is a generator of an Iterable of generators that are chained
>>> # together by nested function calling. The following is equivalent of
>>> # defining a overall generator of `shuffle(untuple(...))` which
>>> # untuples the samples and shuffles them. See webdataset's Documentation
>>> # for details.
>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's
>>> # file parsing rule.
>>>
>>> untuple = lambda source : (sample for (sample,) in source)
>>>
>>> from webdatast import shuffle
>>> pipeline_wds = {
>>>     Split.train : [untuple, shuffle(n_samples[Split.train],
>>>                                     rng=random.Random(seed_rng_shfl))],
>>>     Split.val: untuple
>>> }
>>>
>>> # Similarly the user can optionally define the processing routine on the
>>> # WebLoader (the dataloader of webdataset).
>>> # NOTE: these routines by default take unbatched sample as input so the
>>> # user can customize their batching routines here
>>>
>>> batch = batched(local_batch_size, collation_fn=lambda
                    list_samples : torch.vstack(list_samples))
>>> pipeline_prebatch_wld = {
        Split.train: [shuffle(n_samples[Split.train],
                              rng=random.Random(seed_rng_shfl)), batch],
        Split.val : batch,
        Split.test : batch
    }
>>>
>>> # the user can optionally specify the kwargs for WebDataset and
>>> # WebLoader
>>>
>>> kwargs_wds = {
>>>     split : {'shardshuffle' : split == Split.train,
>>>              'nodesplitter' : wds.split_by_node,
>>>              'seed' : seed_rng_shfl}
>>>     for split in Split
>>>     }
>>>
>>> kwargs_wld = {
>>>     split : {"num_workers": 2} for split in Split
>>>     }
>>>
>>> invoke_wds = {
>>>     split: [("with_epoch", {"nbatches" : 5})] for split in Split
>>>     }
>>>
>>> invoke_wld = {
>>>     split: [("with_epoch", {"nbatches" : 5}] for split in Split
>>>     }
>>>
>>> # construct the data module
>>> data_module = WebDataModule(suffix_keys_wds,
                                dirs_of_tar_files,
                                prefix_tars_wds=tar_file_prefix,
                                pipeline_wds=pipeline_wds,
                                pipeline_prebatch_wld=pipeline_prebatch_wld,
                                kwargs_wds=kwargs_wds,
                                kwargs_wld=kwargs_wld,
                                invoke_wds=invoke_wds,
                                invoke_wld=invoke_wld,
                                )

__init__

def __init__(
    suffix_keys_wds: Union[str, Iterable[str]],
    dirs_tars_wds: Dict[Split, str],
    prefix_tars_wds: str = "wdshards",
    pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
                                             Iterable[Any]]]] = None,
    pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
                                                      Iterable[Any]]]] = None,
    kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
    kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None,
    invoke_wds: Optional[Dict[Split, List[Tuple[str, Dict[str, Any]]]]] = None,
    invoke_wld: Optional[Dict[Split, List[Tuple[str, Dict[str,
                                                          Any]]]]] = None)

构造函数。

参数:

  • suffix_keys_wds - 一组键,每个键对应于 webdataset tar 文件字典中的一个数据对象。这些键的数据对象将被提取并元组化,用于 tar 文件中的每个样本
  • dirs_tars_wds - 输入字典:Split -> tar 文件目录,其中包含每个分支的 webdataset tar 文件 Kwargs
  • prefix_tars_wds - 输入 webdataset tar 文件名的名称前缀。输入 tar 文件通过 "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" 进行 glob
  • pipeline_wds - webdatast 可组合函数的字典,即,将迭代器映射到另一个迭代器的函子,该迭代器转换从数据集对象产生的数据样本,用于不同的分支,或用于此类迭代器序列的可迭代对象。例如,这可以用于在 worker 中转换样本,然后再将其发送到数据加载器的主进程
  • pipeline_prebatch_wld - webloader 可组合函数的字典,即,将迭代器映射到另一个迭代器的函子,该迭代器转换从 WebLoader 对象产生的数据样本,用于不同的分支,或用于此类迭代器序列的可迭代对象。例如,这可以用于批处理样本。注意:这在从 WebLoader 产生批处理之前应用
  • kwargs_wds - WebDataset.init() 的 kwargs kwargs_wld:WebLoader.init() 的 kwargs,例如,每个分支的 num_workers
  • invoke_wds - 要在 WebDataset 构建时调用的 WebDataset 方法的字典。这些方法必须返回 WebDataset 对象本身。示例包括 .with_length() 和 .with_epoch()。这些方法将在返回 WebDataset 对象结束时应用,即,在应用 pipline_wds 之后。元组的内部列表的第一个元素是方法名称,第二个元素是相应方法的 kwargs。
  • invoke_wld - 要在 WebLoader 构建时调用的 WebLoader 方法的字典。这些方法必须返回 WebLoader 对象本身。示例包括 .with_length() 和 .with_epoch()。这些方法将在返回 WebLoader 对象结束时应用,即,在应用 pipelin_prebatch_wld 之后。元组的内部列表的第一个元素是方法名称,第二个元素是相应方法的 kwargs。

prepare_data

def prepare_data() -> None

这仅由主进程通过 Lightning 工作流程调用。

不要依赖此数据模块对象在此处的状态更新,因为无法将状态更新通信到其他子进程。是一个空操作

setup

def setup(stage: str) -> None

这在多节点训练会话中的所有 Lightning 管理节点上调用。

参数:

  • stage - "fit"、"test" 或 "predict"

train_dataloader

def train_dataloader() -> wds.WebLoader

用于训练数据的 Webdataset。

val_dataloader

def val_dataloader() -> wds.WebLoader

用于验证数据的 Webdataset。

test_dataloader

def test_dataloader() -> wds.WebLoader

用于测试数据的 Webdataset。

predict_dataloader

def predict_dataloader() -> wds.WebLoader

:func:test_dataloader 的别名。

PickledDataWDS 对象

class PickledDataWDS(WebDataModule)

用于将 pickle 数据处理为 webdataset tar 文件的 LightningDataModule。

PickledDataWDS 是一个 LightningDataModule,用于将 pickle 数据处理为 webdataset tar 文件,并设置数据集和数据加载器。这从其父模块 WebDataModule 继承 webdataset 设置。此数据模块接受一个 pickle 数据文件目录、用于 train/val/test 分支的数据文件名前缀、数据文件名后缀,并通过 glob 特定 pickle 数据文件 {dir_pickles}/{name_subset[split]}.{suffix_pickles} 来准备 webdataset tar 文件,并使用字典结构输出到 webdataset tar 文件:注意:这假设每个样本仅处理一个 pickle 文件。在其 setup() 函数中,它创建 webdataset 对象,链接输入 pipeline_wds 工作流程。在其 train/val/test_dataloader() 函数中,它创建 WebLoader 对象,链接 pipeline_prebatch_wld 工作流程。

    {"__key__" : name.replace(".", "-"),
     suffix_pickles : pickled.dumps(data) }

示例:


  1. 使用 pickle 文件目录和用于 Lightning.Trainer.fit() 的不同分支的文件名前缀创建数据模块
>>> from bionemo.core.data.datamodule import Split, PickledDataWDS

>>> dir_pickles = "/path/to/my/pickles/dir"

>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
>>> # validation dataset

>>> suffix_pickles = "mydata.pt"

>>> names_subset = {
>>>     Split.train: [sample1, sample2],
>>>     Split.val: [sample4, sample5],
>>> }

>>> # the following setting will attempt to create at least 5 tar files in
>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`

>>> n_tars_wds = 5
>>> prefix_tars_wds = "myshards"
>>> output_dir_tar_files = {
        Split.train : "/path/to/output/tars/dir-train",
        Split.val : "/path/to/output/tars/dir-val",
        Split.test : "/path/to/output/tars/dir-test",
    }

>>> # user can optionally customize the data processing routines and kwargs used
>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)

>>> pipeline_wds = { Split.train: ... }

>>> pipeline_prebatch_wld = { Split.train: ... }

>>> kwargs_wds = { Split.train: ..., Split.val: ... }

>>> kwargs_wld = { Split.train: ..., Split.val: ... }

>>> invoke_wds = { Split.train: ..., Split.val: ... }

>>> invoke_wld = { Split.train: ..., Split.val: ... }

>>> # create the data module
>>> data_module = PickledDataWDS(
>>>     dir_pickles,
>>>     names_subset,
>>>     suffix_pickles, # `WebDataModule` args
>>>     output_dir_tar_files, # `WebDataModule` args
>>>     n_tars_wds=n_tars_wds,
>>>     prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
>>>     pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
>>>     pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs
>>>     kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
>>>     kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
>>>     invoke_wds=invoke_wds, # `WebDataModule` kwargs
>>>     invoke_wld=invoke_wld, # `WebDataModule` kwargs
>>> )

__init__

def __init__(dir_pickles: str,
             names_subset: Dict[Split, List[str]],
             *args,
             n_tars_wds: Optional[int] = None,
             **kwargs) -> None

构造函数。

参数:

  • dir_pickles - pickle 数据文件的输入目录
  • names_subset - 要在每个分支的数据集和数据加载器中加载的数据样本的文件名前缀列表
  • *args - 传递给父 WebDataModule 的参数
  • n_tars_wds - 尝试创建至少此数量的 webdataset 分片
  • **kwargs - 传递给父 WebDataModule 的参数

prepare_data

def prepare_data() -> None

这仅由主进程通过 Lightning 工作流程调用。

不要依赖此数据模块对象在此处的状态更新,因为无法将状态更新通信到其他子进程。嵌套的 pickles_to_tars 函数遍历不同分支中的数据名称前缀,读取相应的 pickle 文件,并输出具有字典结构的 webdataset tar 存档:{"key" : name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }。