条件式执行和掩码#

本教程曾用于展示在 DALI 中利用算术表达式实现条件结果的两种方法。

DALI 支持条件执行,允许使用 if 语句和标量条件在选定的样本上有条件地运行操作,这现在是推荐的方法。使用适当的条件执行既更高效又更节省内存。您可以在 Pipeline 文档的条件执行部分条件教程中阅读更多相关信息。这种方法取代了以前可以通过算术运算模拟的方法。

如果条件具有更多维度,我们可以对这些数据进行按位运算,以模拟逻辑表达式和每个像素或张量元素的条件执行。在这种情况下,所有子表达式都需要预先评估。我们将在本教程中展示这种方法 - 输出图像将根据掩码的值生成,其中 bool 控制每个输出像素。

使用比较和按位运算生成掩码#

我们将使用比较运算符来构建掩码,这些掩码表示图像具有低像素强度和高像素强度的区域。

首先,我们将计算亮度调整后的图像。接下来,我们为像素和像素构建掩码。如果最亮通道低于某个阈值,则像素被认为是暗像素。类似地,如果其通道的最暗通道高于某个阈值,则像素被认为是亮像素。掩码是通过计算通道维度上的最大值和最小值,并将其与低阈值和高阈值进行比较而获得的。我们获得内部具有 bool 值的掩码。

我们使用按位 OR 运算来构建表示低强度区域和高强度区域并集的掩码。掩码中的值是布尔值,因此按位 |& ^ 运算可以像它们的逻辑对应物一样以元素方式使用。

DALI 算术表达式是元素级的并支持广播。我们可以使用乘法和加法来构建结果图像。这种方法类似于多路复用

即使掩码是 1 通道的 - 形状为 (H, W, 1),而图像有 3 个通道,形状为 (H, W, 3),由于通道维度的自动广播,我们可以将它们相乘。您可以在文档的此部分中阅读有关广播的更多信息。

将图像乘以布尔掩码会将图像中对应于掩码中 False 值的区域归零。由于掩码是不相交的,我们可以通过将增强图像加在一起来组合它们。

请记住,我们首先必须生成所有像素中都变亮和变暗的图像 - 结果是通过根据掩码从这些输入中选择原始值、变亮值或变暗值来构建的 - 与对单个样本使用条件执行相反,这里不涉及任何类型的局部执行。

[1]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.types import Constant

import matplotlib.pyplot as plt
import numpy as np
[2]:
def not_(mask):
    """Emulate logical not operation on tensor data."""
    return True ^ mask


def expand_mask(mask):
    """Expand 1-channel mask into image represented with 3 channels for
    the purpose of displaying it."""
    return fn.cat(mask, mask, mask, axis=2)


@pipeline_def(batch_size=5, num_threads=1, device_id=0)
def masking_pipe():
    input_buf, _ = fn.readers.file(
        device="cpu",
        file_root="../../data/images",
        file_list="../../data/images/file_list.txt",
    )
    imgs = fn.decoders.image(input_buf, device="cpu", output_type=types.RGB)

    imgs_gray = fn.color_space_conversion(
        imgs, image_type=types.RGB, output_type=types.GRAY
    )
    imgs_bright = fn.brightness_contrast(imgs, brightness=3)
    imgs_dark = fn.brightness_contrast(imgs, brightness=0.75)

    mask_low = fn.reductions.max(imgs_gray, axes=-1, keep_dims=True) < 30
    mask_high = fn.reductions.min(imgs_gray, axes=-1, keep_dims=True) > 230

    mask_other = not_(mask_low | mask_high)

    out = mask_low * imgs_bright + mask_high * imgs_dark + mask_other * imgs

    return out, imgs, expand_mask(mask_other * Constant(255).uint8())
[3]:
mask_pipe = masking_pipe()
mask_pipe.build()

让我们显示结果图像:增强后的图像和原始图像,以及亮度调整区域为黑色的掩码。

[4]:
def display(augmented, reference, mask, cpu=True):
    data_idx = 0
    fig, axes = plt.subplots(len(augmented), 3, figsize=(15, 15))
    for i in range(len(augmented)):
        img = augmented.at(i) if cpu else augmented.as_cpu().at(i)
        ref = reference.at(i) if cpu else reference.as_cpu().at(i)
        m = mask.at(i) if cpu else mask.as_cpu().at(i)
        axes[i, 0].imshow(np.squeeze(img))
        axes[i, 1].imshow(np.squeeze(ref))
        axes[i, 2].imshow(np.squeeze(m))
        axes[i, 0].axis("off")
        axes[i, 1].axis("off")
        axes[i, 2].axis("off")
        axes[i, 0].set_title("Augmented image")
        axes[i, 1].set_title("Reference decoded image")
        axes[i, 2].set_title("Calculated mask")
[5]:
output, reference, mask = mask_pipe.run()
display(output, reference, mask)
../../../_images/examples_general_expressions_expr_conditional_and_masking_7_0.png