重要
您正在查看 NeMo 2.0 文档。此版本对 API 和一个新的库 NeMo Run 进行了重大更改。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
分类器#
基类#
- class nemo_curator.image.classifiers.ImageClassifier(
- model_name: str,
- embedding_column: str,
- pred_column: str,
- pred_type: str | type,
- batch_size: int,
- embedding_size: int,
一个抽象基类,表示基于 CLIP 视觉编码器生成的嵌入之上的分类器。
子类只需要定义如何加载模型。如果他们想在预测输出序列合并到数据集之前对其进行修改,他们也可以覆盖 postprocess 方法。分类器必须能够在单个 GPU 上运行。
- abstract load_model(device: str) Callable #
加载分类器模型。
- 参数:
device (str) – 一个 PyTorch 设备标识符,指定在哪个 GPU 上加载模型。
- 返回:
- 一个可调用模型,通常是 torch.nn.Module。
此模型的输入将是由 ImageEmbedder.load_dataset_shard 输出的图像批次。
- 返回类型:
Callable
- postprocess(series: cudf.Series) cudf.Series #
在将分类器的预测保存到元数据之前对其进行后处理。
- 参数:
series (cudf.Series) – 原始模型预测的 cuDF 系列。
- 返回:
- 相同的系列,未修改。如果需要,在您的分类器中覆盖
if needed.
- 返回类型:
cudf.Series
图像分类器#
- class nemo_curator.image.classifiers.AestheticClassifier(
- embedding_column: str = 'image_embedding',
- pred_column: str = 'aesthetic_score',
- batch_size: int = -1,
- model_path: str | None = None,
LAION-Aesthetics_Predictor V2 是一个线性分类器,基于 OpenAI CLIP ViT-L/14 图像嵌入进行训练。它用于评估图像的美学质量。有关该模型的更多信息,请访问:https://laion.ai/blog/laion-aesthetics/。
- load_model(device)#
加载分类器模型。
- 参数:
device (str) – 一个 PyTorch 设备标识符,指定在哪个 GPU 上加载模型。
- 返回:
- 一个可调用模型,通常是 torch.nn.Module。
此模型的输入将是由 ImageEmbedder.load_dataset_shard 输出的图像批次。
- 返回类型:
Callable
- postprocess(series)#
在将分类器的预测保存到元数据之前对其进行后处理。
- 参数:
series (cudf.Series) – 原始模型预测的 cuDF 系列。
- 返回:
- 相同的系列,未修改。如果需要,在您的分类器中覆盖
if needed.
- 返回类型:
cudf.Series
- class nemo_curator.image.classifiers.NsfwClassifier(
- embedding_column: str = 'image_embedding',
- pred_column: str = 'nsfw_score',
- batch_size: int = -1,
- model_path: str | None = None,
NSFW 分类器是一个小型 MLP,基于 OpenAI 的 ViT-L CLIP 图像嵌入进行训练。它用于评估图像包含性暴露内容的可能性。有关该模型的更多信息,请访问:LAION-AI/CLIP-based-NSFW-Detector。
- load_model(device)#
加载分类器模型。
- 参数:
device (str) – 一个 PyTorch 设备标识符,指定在哪个 GPU 上加载模型。
- 返回:
- 一个可调用模型,通常是 torch.nn.Module。
此模型的输入将是由 ImageEmbedder.load_dataset_shard 输出的图像批次。
- 返回类型:
Callable
- postprocess(series)#
在将分类器的预测保存到元数据之前对其进行后处理。
- 参数:
series (cudf.Series) – 原始模型预测的 cuDF 系列。
- 返回:
- 相同的系列,未修改。如果需要,在您的分类器中覆盖
if needed.
- 返回类型:
cudf.Series