Resamplers

PRNGResampleDataset

Bases: Dataset[T_co]

A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

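PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). Controlling the random seed makes the shuffling reproducible, and the list of indices is never stored in memory. It works by generating random indices on the assumption that the caller requests them sequentially. Random lookups are supported, but they require recomputing the PRNG state, which is slow: if the last requested index was greater than the requested one, the state is reset to the initial seed and advanced linearly from 0. This should work well with the megatron sampler, which is sequential. Skipped lookups, as will happen with multiple workers, are handled by advancing the PRNG state past the skipped indices.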
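The idea of seed-controlled, lazily generated indices can be sketched with the standard library alone. This is a minimal illustration, not the class's implementation; `random_index_stream` is a hypothetical helper name introduced here.

```python
import itertools
import random
from typing import Iterator


def random_index_stream(dataset_len: int, seed: int) -> Iterator[int]:
    """Yield pseudo-random indices one at a time; no index list is materialized."""
    rng = random.Random(seed)
    while True:
        yield rng.randint(0, dataset_len - 1)


# Two streams built from the same seed produce the same order of indices.
first_run = list(itertools.islice(random_index_stream(100, seed=42), 10))
second_run = list(itertools.islice(random_index_stream(100, seed=42), 10))
assert first_run == second_run
```

Only the generator state is kept between draws, which is why memory use does not grow with the number of samples.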

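Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler

This class performs sampling with replacement of the underlying dataset. The epoch-based sampling provided by bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler is recommended instead, because it ensures that each sample is seen exactly once per epoch. This dataset is useful when the dataset is too large for the shuffled list of indices to fit in memory and exhaustive sampling is not required.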
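The difference between the two sampling styles can be seen in a small standard-library sketch. A plain shuffled permutation stands in for epoch-based sampling here; it is an assumption for illustration and not how MultiEpochDatasetResampler is implemented.

```python
import random

n = 1000
rng = random.Random(42)

# Sampling with replacement: each draw is independent, so repeats are expected
# and some items are never visited within n draws.
with_replacement = [rng.randint(0, n - 1) for _ in range(n)]
unique_seen = len(set(with_replacement))

# Epoch-style sampling: a permutation visits every index exactly once, but the
# full index list has to be held in memory.
permutation = list(range(n))
random.Random(42).shuffle(permutation)

print(f"with replacement: {unique_seen} of {n} items seen in {n} draws")
print(f"permutation:      {len(set(permutation))} of {n} items seen in {n} draws")
```

For large n, drawing n times with replacement is expected to touch roughly 63% of the items (1 - 1/e), which is why exhaustive coverage calls for the epoch-based approach.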

Source code in bionemo/core/data/resamplers.py
class PRNGResampleDataset(Dataset[T_co]):
    """A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

    PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for
    reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory. It
    works by generating random indices assuming that the requesting function asks for them sequentially. Although random
    lookups are supported, they involve recomputing state, which is slow, and require linearly advancing from 0 if the
    last requested index was greater than this requested index. This should work well with the megatron sampler, which
    is sequential. It handles skipped lookups, as will happen with multiple workers, by advancing the PRNG state past
    the skipped indices.

    !!! warning "Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler"

        This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based
        sampling provided by `bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler` instead, which ensures
        that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large
        for the shuffled list of indices to fit in memory and exhaustive sampling is not required.
    """

    def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
        """Initializes the PRNGResampleDataset.

        Args:
            dataset: The dataset to be shuffled.
            seed: The seed value for the PRNG. Default is 42.
            num_samples: The number of samples to draw from the dataset.
                If None, the length of the dataset is used. Default is None.
        """
        self.initial_seed = seed
        self.rng = random.Random(seed)
        self.dataset_len = len(dataset)  # type: ignore
        self.num_samples = num_samples if num_samples is not None else len(dataset)
        self.dataset = dataset
        # Store the last accessed index. On the first pass this is initialized to infinity, which triggers a reset since
        #  index - inf < 0 for all values of index. This leads to `self.advance_state(index)` being called, which advances
        #  the state to the correct starting index. `last_index` is then replaced by `index` and the algorithm
        #  proceeds normally.
        self.last_index: Union[int, float] = math.inf  # math.inf is a float value, not a type
        self.last_rand_index: Optional[int] = None

    def rand_idx(self) -> int:
        """Generates a random index within the range of the dataset size."""
        return self.rng.randint(0, self.dataset_len - 1)

    def advance_state(self, num_to_advance: int):
        """Advances the PRNG state by generating `num_to_advance` random indices and discarding them.

        Args:
            num_to_advance: The number of random state steps to advance.
        """
        for _ in range(num_to_advance):
            self.rand_idx()

    def __getitem__(self, index: int) -> T_co:
        """Returns the item from the dataset at the specified index.

        Args:
            index: The index of the item to retrieve.

        Returns:
            The item from the dataset at the specified index.

        Note:
            If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
            and advanced to the correct state. This is less efficient than advancing forward.
        """
        idx_diff = index - self.last_index
        if idx_diff < 0:
            # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
            #   then advancing to just before the correct index, which is accomplished with `range(index)`.
            self.rng = random.Random(self.initial_seed)
            self.advance_state(index)
        elif idx_diff == 0:
            # If the index is the same as the last index, we can just return the last random index that was generated.
            #  no state needs to be updated in this case so just return.
            return self.dataset[self.last_rand_index]
        else:
            # We need to advance however many steps were skipped since the last call. Since i+1 - i = 1, we need to advance
            #  by `idx_diff - 1` to accommodate skipped indices.
            self.advance_state(idx_diff - 1)
        self.last_index = index
        # Store the last random index in case the user requests this index again. Drawing it advances the state by 1.
        self.last_rand_index = self.rand_idx()
        return self.dataset[self.last_rand_index]

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return self.num_samples

__getitem__(index)

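Returns the item from the dataset at the specified index.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| index | int | The index of the item to retrieve. | required |

Returns:

| Type | Description |
| --- | --- |
| T_co | The item from the dataset at the specified index. |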

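Note

If the requested index is before the last accessed index, the PRNG state is reset to the initial seed and advanced to the correct state. This is less efficient than advancing forward.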
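The cost difference can be made concrete by counting PRNG draws. The class below is a hypothetical stand-in, `CountingResampler`, that mirrors the access logic described in this section and returns the random index itself rather than a dataset item; it is a sketch for illustration, not the library class.

```python
import random


class CountingResampler:
    """Hypothetical stand-in that mirrors the lookup logic and counts PRNG draws."""

    def __init__(self, dataset_len: int, seed: int = 42):
        self.dataset_len = dataset_len
        self.seed = seed
        self.rng = random.Random(seed)
        self.last_index = float("inf")  # forces a reset on the first lookup
        self.last_rand_index = None
        self.draws = 0  # number of PRNG calls made so far

    def _draw(self) -> int:
        self.draws += 1
        return self.rng.randint(0, self.dataset_len - 1)

    def lookup(self, index: int) -> int:
        diff = index - self.last_index
        if diff == 0:
            return self.last_rand_index  # cached, no draw needed
        if diff < 0:
            # Backward jump (or first call): replay from the seed.
            self.rng = random.Random(self.seed)
            to_skip = index
        else:
            to_skip = diff - 1  # indices skipped since the last lookup
        for _ in range(to_skip):
            self._draw()
        self.last_index = index
        self.last_rand_index = self._draw()
        return self.last_rand_index


sampler = CountingResampler(dataset_len=100)
forward = [sampler.lookup(i) for i in range(1000)]
forward_draws = sampler.draws  # one draw per sequential lookup

sampler.draws = 0
value = sampler.lookup(500)  # jump backward from index 999 to index 500
backward_draws = sampler.draws  # 500 discarded draws plus the one that is kept
```

In this sketch, 1000 sequential lookups cost 1000 draws in total, while the single backward lookup costs 501 draws and still returns the same value that index 500 produced on the forward pass.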

Source code in bionemo/core/data/resamplers.py
def __getitem__(self, index: int) -> T_co:
    """Returns the item from the dataset at the specified index.

    Args:
        index: The index of the item to retrieve.

    Returns:
        The item from the dataset at the specified index.

    Note:
        If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
        and advanced to the correct state. This is less efficient than advancing forward.
    """
    idx_diff = index - self.last_index
    if idx_diff < 0:
        # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
        #   then advancing to just before the correct index, which is accomplished with `range(index)`.
        self.rng = random.Random(self.initial_seed)
        self.advance_state(index)
    elif idx_diff == 0:
        # If the index is the same as the last index, we can just return the last random index that was generated.
        #  no state needs to be updated in this case so just return.
        return self.dataset[self.last_rand_index]
    else:
        # We need to advance however many steps were skipped since the last call. Since i+1 - i = 1, we need to advance
        #  by `idx_diff - 1` to accommodate skipped indices.
        self.advance_state(idx_diff - 1)
    self.last_index = index
    # Store the last random index in case the user requests this index again. Drawing it advances the state by 1.
    self.last_rand_index = self.rand_idx()
    return self.dataset[self.last_rand_index]

__init__(dataset, seed=42, num_samples=None)

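Initializes the PRNGResampleDataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dataset | Dataset[T_co] | The dataset to be shuffled. | required |
| seed | int | The seed value for the PRNG. Default is 42. | 42 |
| num_samples | Optional[int] | The number of samples to draw from the dataset. If None, the length of the dataset is used. Default is None. | None |

Source code in bionemo/core/data/resamplers.py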
def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
    """Initializes the PRNGResampleDataset.

    Args:
        dataset: The dataset to be shuffled.
        seed: The seed value for the PRNG. Default is 42.
        num_samples: The number of samples to draw from the dataset.
            If None, the length of the dataset is used. Default is None.
    """
    self.initial_seed = seed
    self.rng = random.Random(seed)
    self.dataset_len = len(dataset)  # type: ignore
    self.num_samples = num_samples if num_samples is not None else len(dataset)
    self.dataset = dataset
    # Store the last accessed index. On the first pass this is initialized to infinity, which triggers a reset since
    #  index - inf < 0 for all values of index. This leads to `self.advance_state(index)` being called, which advances
    #  the state to the correct starting index. `last_index` is then replaced by `index` and the algorithm
    #  proceeds normally.
    self.last_index: Union[int, float] = math.inf  # math.inf is a float value, not a type
    self.last_rand_index: Optional[int] = None

__len__()

Returns the total number of samples in the dataset.

Source code in bionemo/core/data/resamplers.py
def __len__(self) -> int:
    """Returns the total number of samples in the dataset."""
    return self.num_samples

advance_state(num_to_advance)

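Advances the PRNG state by generating num_to_advance random indices and discarding them.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_to_advance | int | The number of random state steps to advance. | required |

Source code in bionemo/core/data/resamplers.py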
def advance_state(self, num_to_advance: int):
    """Advances the PRNG state by generating `num_to_advance` random indices and discarding them.

    Args:
        num_to_advance: The number of random state steps to advance.
    """
    for _ in range(num_to_advance):
        self.rand_idx()
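
Why discarding draws keeps readers consistent can be checked with a short standard-library sketch. It assumes a reader that only visits every fourth position, as one of several workers might; the stride and variable names are illustrative and not taken from the library.

```python
import random

dataset_len, seed = 100, 42

# Reference: the full sequence that a single sequential reader would see.
ref_rng = random.Random(seed)
full_sequence = [ref_rng.randint(0, dataset_len - 1) for _ in range(20)]

# A reader that only visits positions 1, 5, 9, 13, 17.
rng = random.Random(seed)
visited = {}
last_position = -1
for position in range(1, 20, 4):
    for _ in range(position - last_position - 1):
        rng.randint(0, dataset_len - 1)  # advance past positions this reader skipped
    visited[position] = rng.randint(0, dataset_len - 1)
    last_position = position

# Every visited position matches the value from the full sequence.
assert all(full_sequence[p] == v for p, v in visited.items())
```

Because skipped positions still consume one draw each, the value at a given position does not depend on which positions a particular reader happened to request.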

rand_idx()

Generates a random index within the range of the dataset size.

Source code in bionemo/core/data/resamplers.py
def rand_idx(self) -> int:
    """Generates a random index within the range of the dataset size."""
    return self.rng.randint(0, self.dataset_len - 1)
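
The `- 1` in the upper bound matters because `random.Random.randint` includes both endpoints. A small check, using an arbitrary dataset length of 5 chosen for illustration:

```python
import random

dataset_len = 5
rng = random.Random(42)

# randint(a, b) includes both endpoints, so the upper bound must be dataset_len - 1.
draws = [rng.randint(0, dataset_len - 1) for _ in range(10_000)]
assert min(draws) >= 0 and max(draws) <= dataset_len - 1

# randrange(n) excludes n and therefore covers the same range of valid indices.
alt = [rng.randrange(dataset_len) for _ in range(10_000)]
assert set(alt) <= set(range(dataset_len))
```

Passing `dataset_len` itself as the upper bound to `randint` would occasionally produce an out-of-range index.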