实用程序
safe_index(tensor, index, device)
使用给定索引安全地索引张量,并在指定设备上返回结果。
注意,可以使用 return tensor[index.to(tensor.device)].to(device) 实现强制转换,但迁移成本很高。
参数
名称 | 类型 | 描述 | 默认值 |
---|---|---|---|
张量
|
张量
|
要索引的张量。 |
必需 |
索引
|
张量
|
用于索引张量的索引。 |
必需 |
设备
|
设备
|
结果应返回的设备。 |
必需 |
返回值
名称 | 类型 | 描述 |
---|---|---|
张量 |
指定设备上已索引的张量。 |
引发
类型 | 描述 |
---|---|
ValueError
|
如果张量、索引和设备不在同一设备上。 |
源代码位于 bionemo/moco/interpolants/discrete_time/utils.py
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
|