自定义变换

在计算机视觉中使用 Datasets, Pipeline, TfmdListsTransform

概述

from fastai.vision.all import *

创建自己的 Transform

创建自己的 Transform 比你想象的要容易得多。事实上,每次你将标签函数传递给数据块 API 或 [ImageDataLoaders.from_name_func](https://docs.fastai.net.cn/vision.data.html#imagedataloaders.from_name_func) 时,你就已经在不知不觉中创建了一个 Transform。从本质上讲,Transform 只是一个函数。让我们展示如何通过实现一个封装 [albumentations 库](https://github.com/albumentations-team/albumentations) 中数据增强的 Transform 来轻松添加一个变换。

首先,你需要安装 albumentations 库。如果需要,取消注释以下单元格来安装

# !pip install albumentations

然后,在一个比我们之前使用的 mnist 图像更大的彩色图像上查看变换结果会更容易,因此我们从 PETS 数据集中加载一些数据。

source = untar_data(URLs.PETS)
items = get_image_files(source/"images")

我们仍然可以使用 PILIlmage.create 打开它

img = PILImage.create(items[0])
img

我们将展示如何封装一个变换,但你也可以轻松地封装你在 Compose 方法中封装的任何一组变换。这里我们来做一些 ShiftScaleRotate

from albumentations import ShiftScaleRotate

albumentations 变换处理 numpy 图像,所以我们只需将我们的 [PILImage](https://docs.fastai.net.cn/vision.core.html#pilimage) 转换为 numpy 数组,然后再使用 PILImage.create 重新封装(此函数接受文件名以及数组或张量)。

aug = ShiftScaleRotate(p=1)
def aug_tfm(img): 
    np_img = np.array(img)
    aug_img = aug(image=np_img)['image']
    return PILImage.create(aug_img)
aug_tfm(img)

每次需要 Transform 时,我们都可以传递这个函数,fastai 库会自动进行转换。这是因为你可以直接传递这样的函数来创建一个 Transform

tfm = Transform(aug_tfm)

如果你的变换需要维护一些状态,你可能需要创建一个 Transform 的子类。在这种情况下,你要应用的函数应该写在 encodes 方法中(就像你为 PyTorch 模块实现 forward 方法一样)

class AlbumentationsTransform(Transform):
    def __init__(self, aug): self.aug = aug
    def encodes(self, img: PILImage):
        aug_img = self.aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

我们还添加了类型注解:这将确保此变换仅应用于 [PILImage](https://docs.fastai.net.cn/vision.core.html#pilimage) 及其子类。对于任何其他对象,它不会做任何事情。你也可以编写任意数量带有不同类型注解的 encodes 方法,Transform 会正确地分派它接收到的对象。

这是因为在实践中,变换通常作为 item_tfms(或 batch_tfms)应用于你在数据块 API 中传递的数据。这些数据是不同类型的对象组成的元组,变换可能对元组的每个部分有不同的行为。

让我们在这里看看它是如何工作的

tfm = AlbumentationsTransform(ShiftScaleRotate(p=1))
a,b = tfm((img, 'dog'))
show_image(a, title=b);

变换应用于元组 (img, "dog")img 是一个 [PILImage](https://docs.fastai.net.cn/vision.core.html#pilimage),因此它应用了我们编写的 encodes 方法。"dog" 是一个字符串,所以变换没有对其做任何操作。

然而,有时你需要变换将整个元组作为输入:例如,albumentations 同时应用于图像和分割掩码。在这种情况下,你需要继承 ItemTransfrom 而不是 Transform。让我们看看它是如何工作的

cv_source = untar_data(URLs.CAMVID_TINY)
cv_items = get_image_files(cv_source/'images')
img = PILImage.create(cv_items[0])
mask = PILMask.create(cv_source/'labels'/f'{cv_items[0].stem}_P{cv_items[0].suffix}')
ax = img.show()
ax = mask.show(ctx=ax)

然后我们编写一个 ItemTransform 的子类,它可以封装任何 albumentations 增强变换,但仅适用于分割问题

class SegmentationAlbumentationsTransform(ItemTransform):
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

我们可以检查它如何应用于元组 (img, mask)。这意味着你可以在任何分割问题中将其作为 item_tfms 传递。

tfm = SegmentationAlbumentationsTransform(ShiftScaleRotate(p=1))
a,b = tfm((img, mask))
ax = a.show()
ax = b.show(ctx=ax)

分割

通过在 after_item 中使用相同的变换,但目标类型不同(此处是分割掩码),目标会通过类型分派系统自动得到正确处理。

cv_source = untar_data(URLs.CAMVID_TINY)
cv_items = get_image_files(cv_source/'images')
cv_splitter = RandomSplitter(seed=42)
cv_split = cv_splitter(cv_items)
cv_label = lambda o: cv_source/'labels'/f'{o.stem}_P{o.suffix}'
class ImageResizer(Transform):
    order=1
    "Resize image to `size` using `resample`"
    def __init__(self, size, resample=BILINEAR):
        if not is_listy(size): size=(size,size)
        self.size,self.resample = (size[1],size[0]),resample

    def encodes(self, o:PILImage): return o.resize(size=self.size, resample=self.resample)
    def encodes(self, o:PILMask):  return o.resize(size=self.size, resample=NEAREST)
tfms = [[PILImage.create], [cv_label, PILMask.create]]
cv_dsets = Datasets(cv_items, tfms, splits=cv_split)
dls = cv_dsets.dataloaders(bs=64, after_item=[ImageResizer(128), ToTensor(), IntToFloatTensor()])

如果我们想使用之前创建的增强变换,我们只需要添加一件事:我们希望它只应用于训练集,而不是验证集。为此,我们通过添加 split_idx=0 来指定它只应用于我们数据划分中的特定 idx(0 代表训练集,1 代表验证集)

class SegmentationAlbumentationsTransform(ItemTransform):
    split_idx = 0
    def __init__(self, aug): self.aug = aug
    def encodes(self, x):
        img,mask = x
        aug = self.aug(image=np.array(img), mask=np.array(mask))
        return PILImage.create(aug["image"]), PILMask.create(aug["mask"])

我们可以检查它如何应用于元组 (img, mask)。这意味着你可以在任何分割问题中将其作为 item_tfms 传递。

cv_dsets = Datasets(cv_items, tfms, splits=cv_split)
dls = cv_dsets.dataloaders(bs=64, after_item=[ImageResizer(128), ToTensor(), IntToFloatTensor(), 
                                              SegmentationAlbumentationsTransform(ShiftScaleRotate(p=1))])
dls.show_batch(max_n=4)

使用不同的变换流水线和 DataBlock API

我们通常会对训练数据集和验证数据集使用不同的变换。目前我们的 AlbumentationsTransform 会对两者执行相同的变换,让我们看看是否能让它更灵活一些,满足我们的需求。

让我们为我们的例子设想一个场景

我希望各种数据增强,例如 HueSaturationValue 或 [Flip](https://docs.fastai.net.cn/vision.augment.html#flip),其操作方式类似于 fastai 的做法,即只在训练数据集上运行,而在验证数据集上不运行。我们需要对我们的 AlbumentationsTransform 做些什么?

class AlbumentationsTransform(DisplayedTransform):
    split_idx,order=0,2
    def __init__(self, train_aug): store_attr()
    
    def encodes(self, img: PILImage):
        aug_img = self.train_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

这是我们新编写的变换。但是有什么变化呢?

我们添加了 split_idx,它决定了在验证集和训练集上运行哪些变换(0 代表训练,1 代表验证,None 代表两者都运行)。

除此之外,我们将 order 设置为 2。这意味着如果存在任何执行大小调整操作的 fastai 变换,它们将在我们的新变换之前完成。这让我们清楚地知道我们的变换何时应用以及如何使用它!

让我们来看一个包含一些 Composed albumentations 变换的例子

import albumentations
def get_train_aug(): return albumentations.Compose([
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2, 
                sat_shift_limit=0.2, 
                val_shift_limit=0.2, 
                p=0.5
            ),
            albumentations.CoarseDropout(p=0.5),
            albumentations.Cutout(p=0.5)
])

我们可以使用 [Resize](https://docs.fastai.net.cn/vision.augment.html#resize) 和我们新的训练增强来定义我们的 ItemTransforms

item_tfms = [Resize(224), AlbumentationsTransform(get_train_aug())]

这次我们使用更高层的 [DataBlock](https://docs.fastai.net.cn/data.block.html#datablock) API

path = untar_data(URLs.PETS)/'images'

def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=item_tfms)

并查看一些数据

dls.train.show_batch(max_n=4)

dls.valid.show_batch(max_n=4)

我们可以看到我们的变换成功地只应用于了训练数据!太棒了!

现在,如果我们希望对训练集和验证集都应用特殊的、不同的行为怎么办?让我们看看

class AlbumentationsTransform(RandTransform):
    "A transform handler for multiple `Albumentation` transforms"
    split_idx,order=None,2
    def __init__(self, train_aug, valid_aug): store_attr()
    
    def before_call(self, b, split_idx):
        self.idx = split_idx
    
    def encodes(self, img: PILImage):
        if self.idx == 0:
            aug_img = self.train_aug(image=np.array(img))['image']
        else:
            aug_img = self.valid_aug(image=np.array(img))['image']
        return PILImage.create(aug_img)

那么让我们来一步步看看这里发生了什么。我们将 split_idx 改为 None,这允许我们在设置 split_idx 时进行控制。

我们还继承了 [RandTransform](https://docs.fastai.net.cn/vision.augment.html#randtransform),这允许我们在 before_call 中设置 split_idx

最后,我们检查当前的 split_idx 是多少。如果是 0,则运行训练增强,否则运行验证增强。

让我们看一个典型的训练设置示例

def get_train_aug(): return albumentations.Compose([
            albumentations.RandomResizedCrop(224,224),
            albumentations.Transpose(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.ShiftScaleRotate(p=0.5),
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2, 
                sat_shift_limit=0.2, 
                val_shift_limit=0.2, 
                p=0.5),
            albumentations.CoarseDropout(p=0.5),
            albumentations.Cutout(p=0.5)
])

def get_valid_aug(): return albumentations.Compose([
    albumentations.CenterCrop(224,224, p=1.),
    albumentations.Resize(224,224)
], p=1.)

接下来我们将构建新的 AlbumentationsTransform

item_tfms = [Resize(256), AlbumentationsTransform(get_train_aug(), get_valid_aug())]

并将其传递给我们的 [DataLoaders](https://docs.fastai.net.cn/data.core.html#dataloaders):> 由于我们已经在组合变换中声明了大小调整,这里不再需要任何 item 变换

dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=item_tfms)

我们可以再次比较训练和验证增强,会发现它们确实不同

dls.train.show_batch(max_n=4)

dls.valid.show_batch(max_n=4)

查看验证 [DataLoader](https://docs.fastai.net.cn/data.load.html#dataloader) 中 x 的形状,我们还会发现 CenterCrop 也已应用

x,_ = dls.valid.one_batch()
print(x.shape)
(64, 3, 224, 224)
注意

我们首先使用了 fastai 的裁剪,因为有些图像尺寸太小,需要进行填充。