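Bases: Dataset[T_co]
A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.
PRNGResampleDataset shuffles a given dataset using a PRNG. Controlling the random seed makes the shuffle reproducible, and the list of indices is never stored in memory. The class generates random indices on the assumption that the caller requests them sequentially. Random lookups are supported but slow, because the PRNG state has to be recomputed: if the last requested index was greater than the requested one, the generator is reset to the initial seed and advanced linearly from 0. Requesting the same index twice in a row returns the cached result. This works well with the Megatron sampler, which is sequential. Skipped lookups, as happen with multiple workers, are handled by advancing the PRNG past the skipped positions without fetching those items.
Warning: Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler
This class samples the underlying dataset with replacement. The epoch-based sampling provided by bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler is recommended instead, because it ensures that each sample is seen exactly once per epoch. This dataset is useful when the dataset is so large that the shuffled list of indices does not fit in memory and exhaustive sampling is not required.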
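The core idea described above can be illustrated with a minimal sketch that uses only the standard library. The helper name `resampled_indices` is hypothetical and is not part of bionemo; it only shows that a seeded PRNG yields a reproducible index order without a stored permutation, and that the draws are made with replacement.

```python
import random
from typing import List


def resampled_indices(dataset_len: int, seed: int, num_samples: int) -> List[int]:
    """Draw `num_samples` indices in [0, dataset_len) from a seeded PRNG, with replacement.

    Hypothetical helper for illustration only; not part of bionemo.
    """
    rng = random.Random(seed)
    return [rng.randint(0, dataset_len - 1) for _ in range(num_samples)]


first = resampled_indices(dataset_len=10, seed=42, num_samples=10)
second = resampled_indices(dataset_len=10, seed=42, num_samples=10)

# Same seed, same order: the "shuffle" is reproducible without keeping a permutation around.
assert first == second
# Every index is valid, but nothing guarantees that each one appears exactly once.
assert all(0 <= i < 10 for i in first)
```

Because the draws are independent, some items may repeat and others may never appear within one pass, which is the reason the warning above points to the epoch-based resampler when exhaustive coverage matters.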
Source code in bionemo/core/data/resamplers.py
# Imports and type variable assumed from the top of the module (not part of the original excerpt).
import math
import random
from typing import Optional, TypeVar, Union

from torch.utils.data import Dataset

T_co = TypeVar("T_co", covariant=True)


class PRNGResampleDataset(Dataset[T_co]):
    """A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

    PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for
    reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory. It
    works by generating random indices assuming that the requesting function asks for them sequentially. Although random
    lookups are supported, random lookups will involve recomputing state which is slow, and involves linearly advancing
    from 0 if the last requested index was greater than this requested index. This should work well with the
    megatron sampler which is sequential. It handles skipped lookups as will happen with multiple workers by
    advancing past those numbers.

    !!! warning "Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler"

        This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based
        sampling provided by `bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler` instead, which ensures
        that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large
        for the shuffled list of indices to fit in memory and exhaustive sampling is not required.
    """

    def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
        """Initializes the PRNGResampleDataset.

        Args:
            dataset: The dataset to be shuffled.
            seed: The seed value for the PRNG. Default is 42.
            num_samples: The number of samples to draw from the dataset.
                If None, the length of the dataset is used. Default is None.
        """
        self.initial_seed = seed
        self.rng = random.Random(seed)
        self.dataset_len = len(dataset)  # type: ignore
        self.num_samples = num_samples if num_samples is not None else len(dataset)
        self.dataset = dataset
        # Store the last accessed index. On the first pass this is initialized to infinity, which triggers a reset since
        # index - inf < 0 for all values of index. This leads to `self.advance_state(index)` being called, which advances
        # the state to the correct starting index. The last_index is then replaced by `index` and the algorithm
        # proceeds normally.
        self.last_index: Union[int, float] = math.inf  # math.inf is a float value, not a type
        self.last_rand_index: Optional[int] = None

    def rand_idx(self) -> int:
        """Generates a random index within the range of the dataset size."""
        return self.rng.randint(0, self.dataset_len - 1)

    def advance_state(self, num_to_advance: int):
        """Advances the PRNG state by generating num_to_advance random indices.

        Args:
            num_to_advance: The number of random state steps to advance.
        """
        for _ in range(num_to_advance):
            self.rand_idx()

    def __getitem__(self, index: int) -> T_co:
        """Returns the item from the dataset at the specified index.

        Args:
            index: The index of the item to retrieve.

        Returns:
            The item from the dataset at the specified index.

        Note:
            If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
            and advanced to the correct state. This is less efficient than advancing forward.
        """
        idx_diff = index - self.last_index
        if idx_diff < 0:
            # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
            # then advancing to just before the correct index, which is accomplished with `range(index)`.
            self.rng = random.Random(self.initial_seed)
            self.advance_state(index)
        elif idx_diff == 0:
            # If the index is the same as the last index, we can just return the last random index that was generated.
            # No state needs to be updated in this case, so just return.
            return self.dataset[self.last_rand_index]
        else:
            # We need to advance however many steps were skipped since the last call. Since i+1 - i = 1, we need to advance
            # by `idx_diff - 1` to accommodate skipped indices.
            self.advance_state(idx_diff - 1)
        self.last_index = index
        # Draw the next random index (this advances the PRNG state by 1) and store it in case the user
        # requests this index again.
        self.last_rand_index = self.rand_idx()
        return self.dataset[self.last_rand_index]

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return self.num_samples
__getitem__(index)
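Returns the item from the dataset at the specified index.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| index | int | The index of the item to retrieve. | required |
Returns:
| Type | Description |
| --- | --- |
| T_co | The item from the dataset at the specified index. |
Note
If the requested index is before the last accessed index, the PRNG state is reset to the initial seed and advanced to the correct state. This is less efficient than advancing forward.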
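The cost difference described in the note can be made visible with a small stand-in. This is a sketch under assumptions: `LazyIndexMap`, its `lookup` method, and its `draws` counter are hypothetical names introduced here, and the class only mirrors the reset-and-advance logic without wrapping a real dataset.

```python
import math
import random


class LazyIndexMap:
    """Hypothetical stand-in: maps a position to its PRNG draw and counts PRNG calls."""

    def __init__(self, dataset_len: int, seed: int = 42):
        self.dataset_len = dataset_len
        self.seed = seed
        self.rng = random.Random(seed)
        self.last_index = math.inf  # forces a reset on the first lookup
        self.last_draw = None
        self.draws = 0  # number of PRNG draws made so far

    def _draw(self) -> int:
        self.draws += 1
        return self.rng.randint(0, self.dataset_len - 1)

    def lookup(self, index: int) -> int:
        diff = index - self.last_index
        if diff == 0:
            return self.last_draw  # same position again: reuse the cached draw
        if diff < 0:
            self.rng = random.Random(self.seed)  # going backwards: restart from the seed
            skip = index
        else:
            skip = diff - 1  # going forwards: discard only the skipped positions
        for _ in range(skip):
            self._draw()
        self.last_index = index
        self.last_draw = self._draw()
        return self.last_draw


m = LazyIndexMap(dataset_len=100)
forward = [m.lookup(i) for i in range(5)]
cost_forward = m.draws  # one draw per sequential lookup

again = m.lookup(2)  # backward jump from position 4 to position 2
cost_backward = m.draws - cost_forward  # replays positions 0, 1 and 2

assert again == forward[2]
assert (cost_forward, cost_backward) == (5, 3)
```

A backward jump to position k costs k + 1 draws regardless of where the previous lookup was, so access patterns that frequently move backwards over a large index range pay a cost linear in the index on every such lookup.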
Source code in bionemo/core/data/resamplers.py
def __getitem__(self, index: int) -> T_co:
    """Returns the item from the dataset at the specified index.

    Args:
        index: The index of the item to retrieve.

    Returns:
        The item from the dataset at the specified index.

    Note:
        If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
        and advanced to the correct state. This is less efficient than advancing forward.
    """
    idx_diff = index - self.last_index
    if idx_diff < 0:
        # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
        # then advancing to just before the correct index, which is accomplished with `range(index)`.
        self.rng = random.Random(self.initial_seed)
        self.advance_state(index)
    elif idx_diff == 0:
        # If the index is the same as the last index, we can just return the last random index that was generated.
        # No state needs to be updated in this case, so just return.
        return self.dataset[self.last_rand_index]
    else:
        # We need to advance however many steps were skipped since the last call. Since i+1 - i = 1, we need to advance
        # by `idx_diff - 1` to accommodate skipped indices.
        self.advance_state(idx_diff - 1)
    self.last_index = index
    # Draw the next random index (this advances the PRNG state by 1) and store it in case the user
    # requests this index again.
    self.last_rand_index = self.rand_idx()
    return self.dataset[self.last_rand_index]
__init__(dataset, seed=42, num_samples=None)
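Initializes the PRNGResampleDataset.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset[T_co] | The dataset to be shuffled. | required |
| seed | int | The seed value for the PRNG. | 42 |
| num_samples | Optional[int] | The number of samples to draw from the dataset. If None, the length of the dataset is used. | None |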
Source code in bionemo/core/data/resamplers.py
def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
    """Initializes the PRNGResampleDataset.

    Args:
        dataset: The dataset to be shuffled.
        seed: The seed value for the PRNG. Default is 42.
        num_samples: The number of samples to draw from the dataset.
            If None, the length of the dataset is used. Default is None.
    """
    self.initial_seed = seed
    self.rng = random.Random(seed)
    self.dataset_len = len(dataset)  # type: ignore
    self.num_samples = num_samples if num_samples is not None else len(dataset)
    self.dataset = dataset
    # Store the last accessed index. On the first pass this is initialized to infinity, which triggers a reset since
    # index - inf < 0 for all values of index. This leads to `self.advance_state(index)` being called, which advances
    # the state to the correct starting index. The last_index is then replaced by `index` and the algorithm
    # proceeds normally.
    self.last_index: Union[int, float] = math.inf  # math.inf is a float value, not a type
    self.last_rand_index: Optional[int] = None
__len__()
Returns the total number of samples in the dataset.
Source code in bionemo/core/data/resamplers.py
def __len__(self) -> int:
    """Returns the total number of samples in the dataset."""
    return self.num_samples
advance_state(num_to_advance)
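Advances the PRNG state by generating num_to_advance random indices.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_to_advance | int | The number of random state steps to advance. | required |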
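To illustrate why discarding draws keeps several consumers consistent, the sketch below compares one consumer that visits every position with a hypothetical worker that is only asked for every second position. The `advance` function and the two-worker split are assumptions made for this example and are not bionemo APIs.

```python
import random


def advance(rng: random.Random, dataset_len: int, n: int) -> None:
    """Discard `n` draws so the generator sits just before the wanted position."""
    for _ in range(n):
        rng.randint(0, dataset_len - 1)


dataset_len, seed = 50, 42

# Reference: every position drawn in order by a single consumer.
ref_rng = random.Random(seed)
reference = [ref_rng.randint(0, dataset_len - 1) for _ in range(8)]

# Hypothetical worker 1 of 2: it is asked only for positions 1, 3, 5 and 7.
rng = random.Random(seed)
last = -1
worker_draws = []
for pos in (1, 3, 5, 7):
    advance(rng, dataset_len, pos - last - 1)  # skip positions handled by the other worker
    worker_draws.append(rng.randint(0, dataset_len - 1))
    last = pos

# The worker sees exactly the values the single consumer saw at those positions.
assert worker_draws == reference[1::2]
```

Each skipped position still costs one PRNG draw, but no dataset item is fetched for it, so the extra work stays small compared with loading a sample.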
Source code in bionemo/core/data/resamplers.py
def advance_state(self, num_to_advance: int):
    """Advances the PRNG state by generating num_to_advance random indices.

    Args:
        num_to_advance: The number of random state steps to advance.
    """
    for _ in range(num_to_advance):
        self.rand_idx()
rand_idx()
Generates a random index within the range of the dataset size.
Source code in bionemo/core/data/resamplers.py
def rand_idx(self) -> int:
    """Generates a random index within the range of the dataset size."""
    return self.rng.randint(0, self.dataset_len - 1)
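The `- 1` in the upper bound matters because `random.Random.randint(a, b)` includes both endpoints. A quick check with a deliberately tiny, made-up dataset length shows that every draw is a valid index:

```python
import random

rng = random.Random(0)
dataset_len = 3  # small illustrative value, not taken from the source

# randint(a, b) can return b itself, so the upper bound must be dataset_len - 1.
seen = {rng.randint(0, dataset_len - 1) for _ in range(1000)}

# Only valid indices 0, 1 and 2 are ever produced; dataset_len itself never appears.
assert seen <= set(range(dataset_len))
```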