重要提示

您正在查看 NeMo 2.0 文档。此版本引入了 API 的重大更改和一个新的库 NeMo Run。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档

NSFW 分类器#

概述#

不适宜工作场所观看 (NSFW) 分类器确定图像包含性露骨材料的可能性。 NeMo Curator 与 基于 CLIP 的 NSFW 检测器 集成,后者输出一个介于 0 和 1 之间的值,其中 1 表示内容为 NSFW。

用例#

在大多数数据处理管道中,删除不安全内容是很常见的做法,以防止您的生成式 AI 模型学习生成不安全材料。例如,Data Comp 在进行实验之前会过滤掉 NSFW 内容。

先决条件#

请务必查看图像策展入门页面,以安装您需要的一切。

用法#

NSFW 分类器是一个小型 MLP 分类器,它将 OpenAI CLIP ViT-L/14 图像嵌入作为输入。此模型可通过 vit_large_patch14_clip_quickgelu_224.openai 标识符在 TimmImageEmbedder 中获得。首先,我们可以计算这些嵌入,然后我们可以执行分类。

from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder
from nemo_curator.image.classifiers import NsfwClassifier

client = get_client(cluster_type="gpu")

dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")

embedding_model = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    normalize_embeddings=True,
)
safety_classifier = NsfwClassifier()

dataset_with_embeddings = embedding_model(dataset)
dataset_with_nsfw_scores = safety_classifier(dataset_with_embeddings)

# Metadata will have a new column named "nsfw_score"
dataset_with_nsfw_scores.save_metadata()

关键参数#

  • batch_size=-1 是可选的批量大小参数。默认情况下,它将一次处理分片中的所有嵌入。由于 NSFW 分类器是一个小型模型,因此这通常是可以的。

性能注意事项#

由于 NSFW 模型非常小,您可以将其加载到 GPU 上,与嵌入模型同时进行,并在计算嵌入后直接执行推理。查看此示例

from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder
from nemo_curator.image.classifiers import NsfwClassifier

client = get_client(cluster_type="gpu")

dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")

embedding_model = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    normalize_embeddings=True,
    classifiers=[NsfwClassifier()],
)

dataset_with_nsfw_scores = embedding_model(dataset)

# Metadata will have a new column named "nsfw_score"
dataset_with_nsfw_scores.save_metadata()

其他资源#