跳到内容

工具

sample_or_truncate(gene_ids, max_length, sample=True)

截断和填充样本。

参数

名称 类型 描述 默认值
gene_ids ndarray

基因 ID 数组。

必需
max_length int

样本的最大长度。

必需
sample bool

是否对样本进行采样或截断。默认为 True。

True

返回值

类型 描述
ndarray

np.array: 包含截断或填充后的基因 ID 的元组。

源代码位于 bionemo/geneformer/data/singlecell/utils.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def sample_or_truncate(
    gene_ids: np.ndarray,
    max_length: int,
    sample: bool = True,
) -> np.ndarray:
    """Truncate and pad samples.

    Args:
        gene_ids (np.ndarray): Array of gene IDs.
        max_length (int): Maximum length of the samples.
        sample (bool, optional): Whether to sample or truncate the samples. Defaults to True.

    Returns:
        np.array: Tuple containing the truncated or padded gene IDs.
    """
    if len(gene_ids) <= max_length:
        return gene_ids

    if sample:
        indices = np.random.permutation(len(gene_ids))[:max_length]
        return gene_ids[indices]
    else:
        return gene_ids[:max_length]