重要提示
您正在查看 NeMo 2.0 文档。此版本引入了 API 的重大更改和一个新的库 NeMo Run。我们目前正在将 NeMo 1.0 的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
NSFW 分类器#
概述#
不适宜工作场所观看 (NSFW) 分类器确定图像包含性露骨材料的可能性。 NeMo Curator 与 基于 CLIP 的 NSFW 检测器 集成,后者输出一个介于 0 和 1 之间的值,其中 1 表示内容为 NSFW。
用例#
在大多数数据处理管道中,删除不安全内容是很常见的做法,以防止您的生成式 AI 模型学习生成不安全材料。例如,Data Comp 在进行实验之前会过滤掉 NSFW 内容。
先决条件#
请务必查看图像策展入门页面,以安装您需要的一切。
用法#
NSFW 分类器是一个小型 MLP 分类器,它将 OpenAI CLIP ViT-L/14 图像嵌入作为输入。此模型可通过 vit_large_patch14_clip_quickgelu_224.openai
标识符在 TimmImageEmbedder
中获得。首先,我们可以计算这些嵌入,然后我们可以执行分类。
from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder
from nemo_curator.image.classifiers import NsfwClassifier
client = get_client(cluster_type="gpu")
dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
embedding_model = TimmImageEmbedder(
"vit_large_patch14_clip_quickgelu_224.openai",
pretrained=True,
batch_size=1024,
num_threads_per_worker=16,
normalize_embeddings=True,
)
safety_classifier = NsfwClassifier()
dataset_with_embeddings = embedding_model(dataset)
dataset_with_nsfw_scores = safety_classifier(dataset_with_embeddings)
# Metadata will have a new column named "nsfw_score"
dataset_with_nsfw_scores.save_metadata()
关键参数#
batch_size=-1
是可选的批量大小参数。默认情况下,它将一次处理分片中的所有嵌入。由于 NSFW 分类器是一个小型模型,因此这通常是可以的。
性能注意事项#
由于 NSFW 模型非常小,您可以将其加载到 GPU 上,与嵌入模型同时进行,并在计算嵌入后直接执行推理。查看此示例
from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder
from nemo_curator.image.classifiers import NsfwClassifier
client = get_client(cluster_type="gpu")
dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")
embedding_model = TimmImageEmbedder(
"vit_large_patch14_clip_quickgelu_224.openai",
pretrained=True,
batch_size=1024,
num_threads_per_worker=16,
normalize_embeddings=True,
classifiers=[NsfwClassifier()],
)
dataset_with_nsfw_scores = embedding_model(dataset)
# Metadata will have a new column named "nsfw_score"
dataset_with_nsfw_scores.save_metadata()