跳到内容

实用程序

safe_index(tensor, index, device)

使用给定索引安全地索引张量,并在指定设备上返回结果。

注意,可以使用 return tensor[index.to(tensor.device)].to(device) 实现强制转换,但迁移成本很高。

参数

名称 类型 描述 默认值
张量 张量

要索引的张量。

必需
索引 张量

用于索引张量的索引。

必需
设备 设备

结果应返回的设备。

必需

返回值

名称 类型 描述
张量

指定设备上已索引的张量。

引发

类型 描述
ValueError

如果张量、索引和设备不在同一设备上。

源代码位于 bionemo/moco/interpolants/discrete_time/utils.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def safe_index(tensor: Tensor, index: Tensor, device: torch.device):
    """Safely indexes a tensor using a given index and returns the result on a specified device.

    Note can implement forcing with  return tensor[index.to(tensor.device)].to(device) but has costly migration.

    Args:
        tensor (Tensor): The tensor to be indexed.
        index (Tensor): The index to use for indexing the tensor.
        device (torch.device): The device on which the result should be returned.

    Returns:
        Tensor: The indexed tensor on the specified device.

    Raises:
        ValueError: If tensor, index, and device are not all on the same device.
    """
    if not (tensor.device == index.device == device):
        raise ValueError(
            f"Tensor, index, and device must all be on the same device. "
            f"Got tensor.device={tensor.device}, index.device={index.device}, and device={device}."
        )

    return tensor[index]