数据加载器

模块:polygraphy.tools.args

class DataLoaderArgs(allow_custom_input_shapes: bool | None = None)[source]

基类:BaseArgs

数据加载器:加载或生成用于推理的输入数据。

依赖于

  • ModelArgs:如果 allow_custom_input_shapes == True

参数

allow_custom_input_shapes (bool) – 是否允许在随机生成数据时使用自定义输入形状。默认为 True。

parse_impl(args)[source]
seed

用于随机数据生成的种子。

类型

int

val_range

要生成的每个输入的取值范围。

类型

Dict[str, Tuple[int]]

iterations

要生成数据的迭代次数。

类型

int

load_inputs_paths

从中加载输入的路径。

类型

List[str]

data_loader_script

加载输入的自定义脚本路径。

类型

str

data_loader_func_name

自定义数据加载器脚本中用于加载数据的函数名称。

类型

str

data_loader_backend_module

要使用的提供数组的模块。

类型

str

add_to_script_impl(script, user_input_metadata_str=None)[source]
参数

user_input_metadata_str (str(TensorMetadata)) – 包含 TensorMetadata 的变量的名称。这将控制生成数据的形状和数据类型。

返回值

数据加载器,以字符串形式表示。这可以是变量名,

或数据加载器函数的调用。

返回类型

str

get_data_loader(user_input_metadata=None)[source]

根据命令行中提供的参数创建数据加载器。

返回值

Sequence[OrderedDict[str, numpy.ndarray]]

is_using_random_data()[source]

此数据加载器是否将随机生成数据,而不是使用真实数据。

返回值

bool