数据加载器

模块:polygraphy.comparator

class DataLoader(seed=None, iterations=None, input_metadata=None, int_range=None, float_range=None, val_range=None, data_loader_backend_module=None)[source]

基类: object

生成合成输入数据。

参数:
  • seed (int) – 生成随机输入时使用的种子。默认为 util.constants.DEFAULT_SEED

  • iterations (int) – 提供数据的迭代次数。默认为 1。

  • input_metadata (TensorMetadata) – 输入名称到其对应形状和数据类型的映射。这将用于确定为具有动态形状的输入提供什么形状,以及设置生成输入的数据类型。如果 dtype 或 shape 为 None,则将自动确定值。对于输入形状张量,即描述模型中形状的输入,提供的形状将用于填充输入的值,而不是确定其形状。

  • val_range (Union[Tuple[number], Dict[str, Tuple[number]]]) – 一个包含正好 2 个数字的元组,指示数据加载器应生成的最小值和最大值(包括)。如果元组中的任一值为 None,则将为该值使用默认值。如果提供 None 而不是元组,则最小值和最大值都将使用默认值。这可以在每个输入的基础上使用字典指定。在这种情况下,使用空字符串 (“”) 作为键来指定未明确列出的输入的默认范围。默认为 (0.0, 1.0)。

  • data_loader_backend_module (str) – 一个字符串,表示用于构造输入数据数组的模块。当前支持 “numpy” 和 “torch”。默认为 “numpy”。

  • int_range (Tuple[int]) – [已弃用 - 请改用 val_range] 一个包含正好 2 个整数的元组,指示数据加载器应生成的最小和最大整数值(包括)。如果元组中的任一值为 None,则将为该值使用默认值。如果提供 None 而不是元组,则最小值和最大值都将使用默认值。

  • float_range (Tuple[float]) – [已弃用 - 请改用 val_range] 一个包含正好 2 个浮点数的元组,指示数据加载器应生成的最小和最大浮点数值(包括)。如果元组中的任一值为 None,则将为该值使用默认值。如果提供 None 而不是元组,则最小值和最大值都将使用默认值。

__getitem__(index)[source]

生成随机输入数据。

可能会更新 DataLoader 的 input_metadata 属性。

参数:

index (int) – 由于此类行为类似于可迭代对象,因此它接受索引参数。对于相同的索引,保证生成的数据相同。

返回:

输入名称到输入 numpy 缓冲区的映射。

返回类型:

OrderedDict[str, Union[numpy.ndarray, torch.Tensor]]