重要提示

您正在查看 NeMo 2.0 文档。此版本引入了对 API 和新库 NeMo Run 的重大更改。我们目前正在将 NeMo 1.0 中的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档

图像嵌入器#

概述#

NeMo Curator 中的许多图像管理功能直接对图像嵌入而不是图像进行操作。图像嵌入器提供了一种可扩展的方式来为数据集中的每个图像生成嵌入。

用例#

  • 美学和 NSFW 分类都使用从 OpenAI 的 CLIP ViT-L 变体生成的图像嵌入。

  • 语义去重计算数据点的相似性。

先决条件#

请确保查看图像管理入门页面,以安装您将需要的一切。

Timm 图像嵌入器#

PyTorch 图像模型 (timm)是一个包含 SOTA 计算机视觉模型的库。这些模型中的许多模型在为 NeMo Curator 中的模块生成图像嵌入时非常有用。

from nemo_curator import get_client
from nemo_curator.datasets import ImageTextPairDataset
from nemo_curator.image.embedders import TimmImageEmbedder

client = get_client(cluster_type="gpu")

dataset = ImageTextPairDataset.from_webdataset(path="/path/to/dataset", id_col="key")

embedding_model = TimmImageEmbedder(
    "vit_large_patch14_clip_quickgelu_224.openai",
    pretrained=True,
    batch_size=1024,
    num_threads_per_worker=16,
    normalize_embeddings=True,
)

dataset_with_embeddings = embedding_model(dataset)

# Metadata will have a new column named "image_embedding"
dataset_with_embeddings.save_metadata()

在这里,我们加载一个数据集并使用 vit_large_patch14_clip_quickgelu_224.openai 计算图像嵌入。在该过程结束时,我们的元数据文件有一个名为“image_embedding”的新列,其中包含每个数据点的图像嵌入。

关键参数#

  • pretrained=True 确保您下载模型的预训练权重。

  • batch_size=1024 确定在每个单独的 GPU 上一次处理的图像数量。

  • num_threads_per_worker=16 确定 DALI 用于数据加载的线程数。

  • normalize_embeddings=True 将归一化每个嵌入。NeMo Curator 的分类器期望归一化的嵌入作为输入。

性能注意事项#

在底层,图像嵌入模型执行以下操作

  1. 下载模型的权重。

  2. 下载 PyTorch 图像转换(例如,调整大小和中心裁剪)。

  3. 将 PyTorch 图像转换转换为 DALI 转换。

  4. 使用 Dask-cuDF 将元数据的分片(.parquet 文件)加载到您可用的每个 GPU 上。

  5. 将模型的副本加载到每个 GPU 上。

  6. 使用 DALI,以给定的每个工作线程的线程数 (num_threads_per_worker) 反复将大小为 batch_size 的批次加载到每个 GPU 中。

  7. 模型在批次上运行(没有 torch.autocast(),因为 autocast=False)。

  8. 模型的输出嵌入已归一化,因为 normalize_embeddings=True

此流程中有几个关键的性能注意事项。

  • 您必须拥有满足要求的 NVIDIA GPU。

  • 您可以在 tar 文件的同一目录中创建 .idx 文件,以加快数据加载时间。有关更多信息,请参阅 DALI 文档

自定义图像嵌入器#

要编写您自己的自定义嵌入器,您可以从 nemo_curator.image.embedders.ImageEmbedder 继承并覆盖如下所示的两种方法

from nemo_curator.image.embedders import ImageEmbedder

class MyCustomEmbedder(ImageEmbedder):

    def load_dataset_shard(self, tar_path: str) -> Iterable:
        # Implement me!
        pass

    def load_embedding_model(self, device: str) -> Callable:
        # Implement me!
        pass
  • load_dataset_shard() 将接收 tar 文件的路径,并返回分片上的可迭代对象。可迭代对象应返回 (数据批次, 元数据) 的元组。数据批次可以是任何形式。它将直接传递给 load_embedding_model() 返回的模型。元数据应是元数据的字典,其中包含与数据集的 id_col 对应的字段。在我们的示例中,元数据应包含 "key" 的值。

  • load_embedding_model() 将接受设备并返回可调用对象。此可调用对象将以 load_dataset_shard() 生成的数据批次作为输入。

其他资源#