数据加载器
模块:polygraphy.comparator
- class DataLoader(seed=None, iterations=None, input_metadata=None, int_range=None, float_range=None, val_range=None, data_loader_backend_module=None)[source]
基类:
object
生成合成输入数据。
- 参数:
seed (int) – 生成随机输入时使用的种子。默认为
util.constants.DEFAULT_SEED
。iterations (int) – 提供数据的迭代次数。默认为 1。
input_metadata (TensorMetadata) – 输入名称到其对应形状和数据类型的映射。这将用于确定为具有动态形状的输入提供什么形状,以及设置生成输入的数据类型。如果 dtype 或 shape 为 None,则将自动确定值。对于输入形状张量,即值描述模型中形状的输入,提供的形状将用于填充输入的值,而不是确定其形状。
val_range (Union[Tuple[number], Dict[str, Tuple[number]]]) – 一个包含正好 2 个数字的元组,指示数据加载器应生成的最小值和最大值(包括)。如果元组中的任一值为 None,则将为该值使用默认值。如果提供 None 而不是元组,则最小值和最大值都将使用默认值。这可以在每个输入的基础上使用字典指定。在这种情况下,使用空字符串 (“”) 作为键来指定未明确列出的输入的默认范围。默认为 (0.0, 1.0)。
data_loader_backend_module (str) – 一个字符串,表示用于构造输入数据数组的模块。当前支持 “numpy” 和 “torch”。默认为 “numpy”。
int_range (Tuple[int]) – [已弃用 - 请改用 val_range] 一个包含正好 2 个整数的元组,指示数据加载器应生成的最小和最大整数值(包括)。如果元组中的任一值为 None,则将为该值使用默认值。如果提供 None 而不是元组,则最小值和最大值都将使用默认值。
float_range (Tuple[float]) – [已弃用 - 请改用 val_range] 一个包含正好 2 个浮点数的元组,指示数据加载器应生成的最小和最大浮点数值(包括)。如果元组中的任一值为 None,则将为该值使用默认值。如果提供 None 而不是元组,则最小值和最大值都将使用默认值。