bionemo-webdatamodule
要安装,请执行以下命令
pip install -e .
要运行单元测试,请执行
pytest -v .
WebDataModule
class WebDataModule(L.LightningDataModule)
用于使用 webdataset tar 文件的 LightningDataModule。
WebDataModule
是一个 LightningDataModule
,用于使用 webdataset tar 文件来设置 PyTorch 数据集和数据加载器。此数据模块接受一个字典作为输入:Split -> tar 文件目录和各种 webdataset 配置设置。在其 setup() 函数中,它创建 webdataset 对象,链接输入 pipeline_wds
工作流程。在其 train/val/test_dataloader() 函数中,它创建 WebLoader 对象,链接 pipeline_prebatch_wld
工作流程。
示例:
-
使用 webdataset tar 文件的输入目录创建数据模块。根据调用的下游 Lightning.Trainer 方法,例如
Trainer.fit()
、Trainer.validate()
、Trainer.test()
或Trainer.predict()
,只需在数据模块的各种输入选项中指定 train、val 和 test 分支的子集 -
Trainer.fit()
需要train
和val
分支 Trainer.validate()
需要val
分支Trainer.test()
需要test
分支Trainer.predict()
需要test
分支
以下是为 Trainer.fit()
构建数据模块的示例
>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
>>>
>>> tar_file_prefix = "shards"
>>>
>>> dirs_of_tar_files = {
>>> Split.train: "/path/to/train/split/tars",
>>> Split.val: "/path/to/val/split/tars",
>>> }
>>>
>>> n_samples {
>>> Split.train: 1000,
>>> Split.val: 100,
>>> }
>>>
>>> # this is the string to retrieve the corresponding data object from the
>>> # webdataset file (see
>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
>>> # for details)
>>> suffix_keys_wds = "tensor.pyd"
>>>
>>> seed = 27193781
>>>
>>> # Specify the routines to process the samples in the WebDataset object.
>>> # The routine is a generator of an Iterable of generators that are chained
>>> # together by nested function calling. The following is equivalent of
>>> # defining a overall generator of `shuffle(untuple(...))` which
>>> # untuples the samples and shuffles them. See webdataset's Documentation
>>> # for details.
>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's
>>> # file parsing rule.
>>>
>>> untuple = lambda source : (sample for (sample,) in source)
>>>
>>> from webdatast import shuffle
>>> pipeline_wds = {
>>> Split.train : [untuple, shuffle(n_samples[Split.train],
>>> rng=random.Random(seed_rng_shfl))],
>>> Split.val: untuple
>>> }
>>>
>>> # Similarly the user can optionally define the processing routine on the
>>> # WebLoader (the dataloader of webdataset).
>>> # NOTE: these routines by default take unbatched sample as input so the
>>> # user can customize their batching routines here
>>>
>>> batch = batched(local_batch_size, collation_fn=lambda
list_samples : torch.vstack(list_samples))
>>> pipeline_prebatch_wld = {
Split.train: [shuffle(n_samples[Split.train],
rng=random.Random(seed_rng_shfl)), batch],
Split.val : batch,
Split.test : batch
}
>>>
>>> # the user can optionally specify the kwargs for WebDataset and
>>> # WebLoader
>>>
>>> kwargs_wds = {
>>> split : {'shardshuffle' : split == Split.train,
>>> 'nodesplitter' : wds.split_by_node,
>>> 'seed' : seed_rng_shfl}
>>> for split in Split
>>> }
>>>
>>> kwargs_wld = {
>>> split : {"num_workers": 2} for split in Split
>>> }
>>>
>>> invoke_wds = {
>>> split: [("with_epoch", {"nbatches" : 5})] for split in Split
>>> }
>>>
>>> invoke_wld = {
>>> split: [("with_epoch", {"nbatches" : 5}] for split in Split
>>> }
>>>
>>> # construct the data module
>>> data_module = WebDataModule(suffix_keys_wds,
dirs_of_tar_files,
prefix_tars_wds=tar_file_prefix,
pipeline_wds=pipeline_wds,
pipeline_prebatch_wld=pipeline_prebatch_wld,
kwargs_wds=kwargs_wds,
kwargs_wld=kwargs_wld,
invoke_wds=invoke_wds,
invoke_wld=invoke_wld,
)
__init__
def __init__(
suffix_keys_wds: Union[str, Iterable[str]],
dirs_tars_wds: Dict[Split, str],
prefix_tars_wds: str = "wdshards",
pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
Iterable[Any]]]] = None,
pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]],
Iterable[Any]]]] = None,
kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None,
invoke_wds: Optional[Dict[Split, List[Tuple[str, Dict[str, Any]]]]] = None,
invoke_wld: Optional[Dict[Split, List[Tuple[str, Dict[str,
Any]]]]] = None)
构造函数。
参数:
suffix_keys_wds
- 一组键,每个键对应于 webdataset tar 文件字典中的一个数据对象。这些键的数据对象将被提取并元组化,用于 tar 文件中的每个样本dirs_tars_wds
- 输入字典:Split -> tar 文件目录,其中包含每个分支的 webdataset tar 文件 Kwargsprefix_tars_wds
- 输入 webdataset tar 文件名的名称前缀。输入 tar 文件通过 "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" 进行 globpipeline_wds
- webdatast 可组合函数的字典,即,将迭代器映射到另一个迭代器的函子,该迭代器转换从数据集对象产生的数据样本,用于不同的分支,或用于此类迭代器序列的可迭代对象。例如,这可以用于在 worker 中转换样本,然后再将其发送到数据加载器的主进程pipeline_prebatch_wld
- webloader 可组合函数的字典,即,将迭代器映射到另一个迭代器的函子,该迭代器转换从 WebLoader 对象产生的数据样本,用于不同的分支,或用于此类迭代器序列的可迭代对象。例如,这可以用于批处理样本。注意:这在从 WebLoader 产生批处理之前应用kwargs_wds
- WebDataset.init() 的 kwargs kwargs_wld:WebLoader.init() 的 kwargs,例如,每个分支的 num_workersinvoke_wds
- 要在 WebDataset 构建时调用的 WebDataset 方法的字典。这些方法必须返回 WebDataset 对象本身。示例包括 .with_length() 和 .with_epoch()。这些方法将在返回 WebDataset 对象结束时应用,即,在应用 pipline_wds 之后。元组的内部列表的第一个元素是方法名称,第二个元素是相应方法的 kwargs。invoke_wld
- 要在 WebLoader 构建时调用的 WebLoader 方法的字典。这些方法必须返回 WebLoader 对象本身。示例包括 .with_length() 和 .with_epoch()。这些方法将在返回 WebLoader 对象结束时应用,即,在应用 pipelin_prebatch_wld 之后。元组的内部列表的第一个元素是方法名称,第二个元素是相应方法的 kwargs。
prepare_data
def prepare_data() -> None
这仅由主进程通过 Lightning 工作流程调用。
不要依赖此数据模块对象在此处的状态更新,因为无法将状态更新通信到其他子进程。是一个空操作。
setup
def setup(stage: str) -> None
这在多节点训练会话中的所有 Lightning 管理节点上调用。
参数:
stage
- "fit"、"test" 或 "predict"
train_dataloader
def train_dataloader() -> wds.WebLoader
用于训练数据的 Webdataset。
val_dataloader
def val_dataloader() -> wds.WebLoader
用于验证数据的 Webdataset。
test_dataloader
def test_dataloader() -> wds.WebLoader
用于测试数据的 Webdataset。
predict_dataloader
def predict_dataloader() -> wds.WebLoader
:func:test_dataloader
的别名。
PickledDataWDS 对象
class PickledDataWDS(WebDataModule)
用于将 pickle 数据处理为 webdataset tar 文件的 LightningDataModule。
PickledDataWDS
是一个 LightningDataModule,用于将 pickle 数据处理为 webdataset tar 文件,并设置数据集和数据加载器。这从其父模块 WebDataModule
继承 webdataset 设置。此数据模块接受一个 pickle 数据文件目录、用于 train/val/test 分支的数据文件名前缀、数据文件名后缀,并通过 glob 特定 pickle 数据文件 {dir_pickles}/{name_subset[split]}.{suffix_pickles}
来准备 webdataset tar 文件,并使用字典结构输出到 webdataset tar 文件:注意:这假设每个样本仅处理一个 pickle 文件。在其 setup() 函数中,它创建 webdataset 对象,链接输入 pipeline_wds
工作流程。在其 train/val/test_dataloader() 函数中,它创建 WebLoader 对象,链接 pipeline_prebatch_wld
工作流程。
{"__key__" : name.replace(".", "-"),
suffix_pickles : pickled.dumps(data) }
示例:
- 使用 pickle 文件目录和用于
Lightning.Trainer.fit()
的不同分支的文件名前缀创建数据模块
>>> from bionemo.core.data.datamodule import Split, PickledDataWDS
>>> dir_pickles = "/path/to/my/pickles/dir"
>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
>>> # validation dataset
>>> suffix_pickles = "mydata.pt"
>>> names_subset = {
>>> Split.train: [sample1, sample2],
>>> Split.val: [sample4, sample5],
>>> }
>>> # the following setting will attempt to create at least 5 tar files in
>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`
>>> n_tars_wds = 5
>>> prefix_tars_wds = "myshards"
>>> output_dir_tar_files = {
Split.train : "/path/to/output/tars/dir-train",
Split.val : "/path/to/output/tars/dir-val",
Split.test : "/path/to/output/tars/dir-test",
}
>>> # user can optionally customize the data processing routines and kwargs used
>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
>>> pipeline_wds = { Split.train: ... }
>>> pipeline_prebatch_wld = { Split.train: ... }
>>> kwargs_wds = { Split.train: ..., Split.val: ... }
>>> kwargs_wld = { Split.train: ..., Split.val: ... }
>>> invoke_wds = { Split.train: ..., Split.val: ... }
>>> invoke_wld = { Split.train: ..., Split.val: ... }
>>> # create the data module
>>> data_module = PickledDataWDS(
>>> dir_pickles,
>>> names_subset,
>>> suffix_pickles, # `WebDataModule` args
>>> output_dir_tar_files, # `WebDataModule` args
>>> n_tars_wds=n_tars_wds,
>>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
>>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
>>> pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs
>>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
>>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
>>> invoke_wds=invoke_wds, # `WebDataModule` kwargs
>>> invoke_wld=invoke_wld, # `WebDataModule` kwargs
>>> )
__init__
def __init__(dir_pickles: str,
names_subset: Dict[Split, List[str]],
*args,
n_tars_wds: Optional[int] = None,
**kwargs) -> None
构造函数。
参数:
dir_pickles
- pickle 数据文件的输入目录names_subset
- 要在每个分支的数据集和数据加载器中加载的数据样本的文件名前缀列表*args
- 传递给父 WebDataModule 的参数n_tars_wds
- 尝试创建至少此数量的 webdataset 分片**kwargs
- 传递给父 WebDataModule 的参数
prepare_data
def prepare_data() -> None
这仅由主进程通过 Lightning 工作流程调用。
不要依赖此数据模块对象在此处的状态更新,因为无法将状态更新通信到其他子进程。嵌套的 pickles_to_tars
函数遍历不同分支中的数据名称前缀,读取相应的 pickle 文件,并输出具有字典结构的 webdataset tar 存档:{"key" : name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }。