重要提示
您正在查看 NeMo 2.0 文档。此版本引入了对 API 的重大更改和一个新的库,NeMo Run。我们目前正在将 NeMo 1.0 中的所有功能移植到 2.0。有关先前版本或 2.0 中尚不可用的功能的文档,请参阅 NeMo 24.07 文档。
DreamBooth#
模型介绍#
DreamBooth [MM-MODELS-DB2] 是一种微调技术,也是一种个性化大型扩散模型的解决方案,例如 Stable Diffusion,这些模型功能强大,但缺乏模仿给定参考集对象的能力。借助 DreamBooth,您只需要一些特定对象的图像来微调预训练的文本到图像模型,使其学习将唯一标识符与特殊对象绑定。然后,可以使用此唯一标识符来合成在不同场景中上下文化的对象的全新逼真图像。
NeMo 的 DreamBooth 构建于 Stable Diffusion 框架之上。虽然其架构与 Stable Diffusion 相似(请参阅模型配置),但区别在于其训练过程,尤其是在使用不同的数据集并在必要时结合先验保持损失时。
先验保持损失
在小数据集上微调大型预训练语言模型以用于特定任务或文本到图像扩散模型时,经常会出现诸如语言漂移和输出多样性降低等问题。先验保持损失的概念很简单:它使用模型的自生成样本来指导模型,并结合模型预测的这些样本噪声之间的差异。可以使用 model.prior_loss_weight 调整此损失分量的影响。
model_pred, model_pred_prior = torch.chunk(model_output, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
prior_loss = torch.nn.functional.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
loss = loss + prior_loss * self.prior_loss_weight
训练数据集
NeMo 的 DreamBooth 模型数据集与其他 NeMo 多模态模型不同,因为它不需要以 webdataset 格式存储数据。您可以在 [MM-MODELS-DB1] 找到示例数据集。对于您想要集成到模型中的每个对象,只需将其图像(通常为 3-5 张)放在一个文件夹中,并在
model.data.instance_dir
中指定其路径。当使用先验保持损失进行训练时,将原始模型生成的图像存储在一个不同的文件夹中,并在model.data.regularization_dir
中引用其路径。此过程在 NeMo 的 DreamBooth 实现中是自动化的。
模型配置#
有关如何配置 Stable Diffusion,请参阅模型配置。这里我们展示 DreamBooth 特定的配置。
先验保持损失#
model:
with_prior_preservation: False
prior_loss_weight: 0.5
train_text_encoder: False
restore_from_path: /ckpts/nemo-v1-5-188000-ema.nemo #This ckpt is only used to generate regularization images, thus .nemo ckpt is needed
data:
instance_dir: /datasets/instance_dir
instance_prompt: a photo of a sks dog
regularization_dir: /datasets/nemo_dogs
regularization_prompt: a photo of a dog
num_reg_images: 10
num_images_per_prompt: 4
resolution: 512
center_crop: True
train_text_encoder
:指示是否应与 U-Net 一起微调文本编码器。with_prior_preservation
:根据其设置,这会影响模型在正则化数据方面的行为方式。如果设置为False
,则model.prior_loss_weight
和model.restore_from_path
都将被忽略。如果设置为True
,则操作将根据model.data.regularization_dir
中存在的图像数量而有所不同如果计数少于
model.data.num_reg_images
应为
model.restore_from_path
提供一个 .nemo 检查点,以允许推理管道生成正则化图像。model.data.num_images_per_prompt
类似于推理批处理大小,指示一次通过中生成的图像数量,受 GPU 功能限制。model.regularization_prompt
确定推理管道生成图像的文本提示。它通常是model.data.instance_prompt
的变体,减去唯一令牌。一旦满足上述所有参数,推理管道将运行,直到正则化目录中达到所需的图像计数。
如果计数匹配或超过
model.data.num_reg_images
训练将继续进行,而无需调用推理管道,并且上述参数将被忽略。
使用缓存的潜在空间进行训练#
model:
use_cached_latents: True
data:
num_workers: 4
instance_dir: /datasets/instance_dir
instance_prompt: a photo of a sks dog
regularization_dir: /datasets/nemo_dogs
regularization_prompt: a photo of a dog
cached_instance_dir: #/datasets/instance_dir_cached
cached_reg_dir: #/datasets/nemo_dogs_cached
use_cached_latents
:确定是使用在线编码还是预缓存的潜在空间进行训练。cached_instance_dir
:如果启用
use_cached_latents
并且指定了这些包含 .pt 格式潜在空间的目录,则训练将使用潜在空间而不是原始图像。如果未提供缓存目录或潜在空间文件数量与原始图像计数不匹配,则变分自动编码器将在训练前计算图像潜在空间,并将结果保存在磁盘上。
cached_reg_dir
:+ 该逻辑与上述一致,取决于 model.with_prior_preservation 设置。
参考#
Google. Dreambooth. 2023. URL: google/dreambooth。
Nataniel Ruiz, Yuanzhen Li, Varun Jampani, Yael Pritch, Michael Rubinstein, and Kfir Aberman. Dreambooth: fine tuning text-to-image diffusion models for subject-driven generation. 2022. URL: https://arxiv.org/abs/2208.12242。