跳到内容

Permute

permute(index, length, seed)

以恒定的空间和时间复杂度索引到置换数组中。

此函数使用哈希函数将索引 i 置换到范围 [0, l) 中。 有关更多详细信息,请参阅 https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ 以及 Andrew Kensler 的原始算法“相关多重抖动采样”。

参数

名称 类型 描述 默认值
index int

要置换的索引。

必需
length int

置换索引的范围。

必需
seed int

置换种子。

必需

返回值

类型 描述
int

范围 (0, length) 内的置换索引。

源代码位于 bionemo/core/data/permute.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def permute(index: int, length: int, seed: int) -> int:
    """Index into a permuted array with constant space and time complexity.

    This function permutes an index `i` into a range `[0, l)` using a hash function. See
    https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ for more details and
    "Correlated Multi-Jittered Sampling" by Andrew Kensler for the original algorithm.

    Args:
        index: The index to permute.
        length: The range of the permuted index.
        seed: The permutation seed.

    Returns:
        The permuted index in range(0, length).
    """
    if length <= 1:
        raise ValueError("The length of the permuted range must be greater than 1.")

    if index not in range(length):
        raise ValueError("The index to permute must be in the range [0, l).")

    if seed < 0:
        raise ValueError("The permutation seed must be greater than or equal to 0.")

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        w = length - 1
        w |= w >> 1
        w |= w >> 2
        w |= w >> 4
        w |= w >> 8
        w |= w >> 16

        while True:
            index ^= seed
            index *= 0xE170893D
            index ^= seed >> 16
            index ^= (index & w) >> 4
            index ^= seed >> 8
            index *= 0x0929EB3F
            index ^= seed >> 23
            index ^= (index & w) >> 1
            index *= 1 | seed >> 27
            index *= 0x6935FA69
            index ^= (index & w) >> 11
            index *= 0x74DCB303
            index ^= (index & w) >> 2
            index *= 0x9E501CC3
            index ^= (index & w) >> 2
            index *= 0xC860A3DF
            index &= w
            if index < length:
                break

    return (index + seed) % length