GAN

GAN 是指 生成对抗网络 (Generative Adversarial Nets),由 Ian Goodfellow 发明。其概念是同时训练两个模型:生成器和判别器。生成器试图创建与数据集中图像相似的新图像,而判别器则试图区分真实图像和生成器创建的图像。生成器返回图像,判别器返回一个单一数字(通常是概率,假图像为 0,真实图像为 1)。

我们让他们相互对抗,其训练过程大致如下:

  1. 冻结生成器,训练判别器一步:
  1. 冻结判别器,训练生成器一步:
注意

fastai 库通过 GANTrainer 提供对 GAN 训练的支持,但不包含高级模型,只提供基本模型。

封装模块


GANModule

 GANModule (generator:nn.Module=None, critic:nn.Module=None,
            gen_mode:None|bool=False)

封装 generatorcritic 以创建 GAN。

类型 默认值 详情
生成器 Module None 生成器 PyTorch 模块
判别器 Module None 判别器 PyTorch 模块
生成模式 None | bool False GAN 是否应设置为生成模式

这只是一个包含两个模型的壳。调用时,它将根据 gen_mode 的值将输入委托给 generatorcritic


GANModule.switch

 GANModule.switch (gen_mode:None|bool=None)

如果 gen_modeTrue,则将模块置于生成模式,否则置于判别模式。

类型 默认值 详情
生成模式 None | bool None GAN 是否应设置为生成模式

默认情况下(将 gen_mode 设为 None),这将把模块置于另一种模式(如果处于生成模式则切换到判别模式,反之亦然)。


basic_critic

 basic_critic (in_size:int, n_channels:int, n_features:int=64,
               n_extra_layers:int=0, norm_type:NormType=<NormType.Batch:
               1>, ks=3, stride=1, padding=None, bias=None, ndim=2,
               bn_1st=True, act_cls=<class
               'torch.nn.modules.activation.ReLU'>, transpose=False,
               init='auto', xtra=None, bias_std=0.01,
               dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
               padding_mode:str='zeros', device=None, dtype=None)

一个用于 n_channels x in_size x in_size 图像的基本判别器。

类型 默认值 详情
输入尺寸 int 判别器的输入尺寸(与生成器的输出尺寸相同)
通道数 int 判别器输入的通道数
特征数 int 64 判别器中使用的特征数量
额外层数 int 0 判别器中额外隐藏层的数量
归一化类型 NormType NormType.Batch 判别器中使用的归一化类型
核大小 int 3
步长 int 1
填充 NoneType None
偏差 NoneType None
维度 int 2
bn_1st bool True
激活函数类型 type ReLU
转置 bool False
初始化 str auto
额外参数 NoneType None
偏差标准差 float 0.01
膨胀 Union 1
分组 int 1
填充模式 str zeros TODO: 完善此类型
设备 NoneType None
数据类型 NoneType None
返回值 nn.Sequential

AddChannels

 AddChannels (n_dim)

在输入的末尾添加 n_dim 个通道。


basic_generator

 basic_generator (out_size:int, n_channels:int, in_sz:int=100,
                  n_features:int=64, n_extra_layers:int=0, ks=3, stride=1,
                  padding=None, bias=None, ndim=2,
                  norm_type=<NormType.Batch: 1>, bn_1st=True,
                  act_cls=<class 'torch.nn.modules.activation.ReLU'>,
                  transpose=False, init='auto', xtra=None, bias_std=0.01,
                  dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
                  padding_mode:str='zeros', device=None, dtype=None)

一个从 in_szn_channels x out_size x out_size 图像的基本生成器。

类型 默认值 详情
输出尺寸 int 生成器的输出尺寸(与判别器的输入尺寸相同)
通道数 int 生成器输出的通道数
输入尺寸 int 100 生成器的输入噪声向量大小
特征数 int 64 生成器中使用的特征数量
额外层数 int 0 生成器中额外隐藏层的数量
核大小 int 3
步长 int 1
填充 NoneType None
偏差 NoneType None
维度 int 2
归一化类型 NormType NormType.Batch
bn_1st bool True
激活函数类型 type ReLU
转置 bool False
初始化 str auto
额外参数 NoneType None
偏差标准差 float 0.01
膨胀 Union 1
分组 int 1
填充模式 str zeros TODO: 完善此类型
设备 NoneType None
数据类型 NoneType None
返回值 nn.Sequential
critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])

DenseResBlock

 DenseResBlock (nf:int, norm_type:NormType=<NormType.Batch: 1>, ks=3,
                stride=1, padding=None, bias=None, ndim=2, bn_1st=True,
                act_cls=<class 'torch.nn.modules.activation.ReLU'>,
                transpose=False, init='auto', xtra=None, bias_std=0.01,
                dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
                padding_mode:str='zeros', device=None, dtype=None)

具有 nf 特征的 Resnet 块。conv_kwargs 传递给 conv_layer

类型 默认值 详情
特征数 int 特征的数量
归一化类型 NormType NormType.Batch 归一化类型
核大小 int 3
步长 int 1
填充 NoneType None
偏差 NoneType None
维度 int 2
bn_1st bool True
激活函数类型 type ReLU
转置 bool False
初始化 str auto
额外参数 NoneType None
偏差标准差 float 0.01
膨胀 Union 1
分组 int 1
填充模式 str zeros TODO: 完善此类型
设备 NoneType None
数据类型 NoneType None
返回值 SequentialEx

gan_critic

 gan_critic (n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15)

用于训练 GAN 的判别器。

类型 默认值 详情
通道数 int 3 判别器输入的通道数
特征数 int 128 判别器的特征数量
块数 int 3 判别器内的 ResNet 块数量
p float 0.15 判别器中的 dropout 量
返回值 Sequential

GANLoss

 GANLoss (gen_loss_func:Callable, crit_loss_func:Callable,
          gan_model:GANModule)

封装 crit_loss_funcgen_loss_func

类型 详情
生成器损失函数 Callable 生成器损失函数
判别器损失函数 Callable 判别器损失函数
gan_model GANModule GAN 模型

GANLoss.generator

 GANLoss.generator (output, target)

使用判别器评估 output,然后使用 self.gen_loss_func 评估判别器被 output 欺骗的程度

详情
输出 生成器输出
目标 真实图像

GANLoss.critic

 GANLoss.critic (real_pred, input)

使用生成器从 input 生成一些 fake_pred,并在 self.crit_loss_func 中将它们与 real_pred 进行比较。

详情
真实预测 判别器对真实图像的预测
输入 传递给生成器的输入噪声向量

如果调用了 generator 方法,此损失函数期望生成器的 output 和一些 target(一批真实图像)。它将使用 gen_loss_func 评估生成器是否成功欺骗了判别器。此损失函数具有以下签名:

def gen_loss_func(fake_pred, output, target):

以便能够将判别器对 output 的输出(即第一个参数 fake_pred)与 outputtarget 相结合(如果您想将 GAN 损失与其他损失混合)。

如果调用了 critic 方法,此损失函数期望判别器给出的 real_pred 和一些 input(输入给生成器的噪声)。它将使用 crit_loss_func 评估判别器。此损失函数具有以下签名:

def crit_loss_func(real_pred, fake_pred):

其中 real_pred 是判别器在一批真实图像上的输出,而 fake_pred 是使用生成器从噪声生成的。


AdaptiveLoss

 AdaptiveLoss (crit:Callable)

在应用 crit 之前扩展 target 以匹配 output 的大小。


accuracy_thresh_expand

 accuracy_thresh_expand (y_pred:torch.Tensor, y_true:torch.Tensor,
                         thresh:float=0.5, sigmoid:bool=True)

在将 y_true 扩展到 y_pred 的大小后计算阈值精度。

GAN 训练 Callbacks


set_freeze_model

 set_freeze_model (m:torch.nn.modules.module.Module, rg:bool)
类型 详情
模型 Module 要冻结/解冻的模型
rg bool Requires grad 参数。True 表示冻结

GANTrainer

 GANTrainer (switch_eval:bool=False, clip:None|float=None,
             beta:float=0.98, gen_first:bool=False, show_img:bool=True)

处理 GAN 训练的 Callback。

类型 默认值 详情
切换评估模式 bool False 计算损失时模型是否应设置为评估模式
裁剪 None | float None 权重的裁剪量
beta float 0.98 损失的指数加权平滑 beta
生成器优先 bool False 是否先开始生成器训练
显示图像 bool True 训练期间是否显示生成的示例图像
警告

GANTrainer 本身是无用的,需要与以下其中一个 switcher 配合使用


FixedGANSwitcher

 FixedGANSwitcher (n_crit:int=1, n_gen:int=1)

Switcher,先进行 n_crit 次判别器迭代,然后进行 n_gen 次生成器迭代。

类型 默认值 详情
判别器步数 int 1 切换到生成器训练前判别器训练的步数
生成器步数 int 1 切换到判别器训练前生成器训练的步数

AdaptiveGANSwitcher

 AdaptiveGANSwitcher (gen_thresh:None|float=None,
                      critic_thresh:None|float=None)

Switcher,当损失低于 gen_thresh/crit_thresh 时,切换回生成器/判别器。

类型 默认值 详情
生成器阈值 None | float None 生成器的损失阈值
判别器阈值 None | float None 判别器的损失阈值

GANDiscriminativeLR

 GANDiscriminativeLR (mult_lr=5.0)

Callback,处理判别器学习率乘以 mult_lr

GAN 数据


InvisibleTensor

 InvisibleTensor (x, **kwargs)

TensorBase,但 show 方法不执行任何操作


generate_noise

 generate_noise (fn, size=100)

生成噪声向量。

类型 默认值 详情
fn 虚拟参数,使其与 DataBlock 配合使用
尺寸 int 100 返回的噪声向量大小
返回值 InvisibleTensor

我们使用 generate_noise 函数生成噪声向量,将其传递给生成器进行图像生成。

bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop), 
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)

GAN Learner


gan_loss_from_func

 gan_loss_from_func (loss_gen:Callable, loss_crit:Callable,
                     weights_gen:None|collections.abc.MutableSequence|tupl
                     e=None)

loss_genloss_crit 定义 GAN 的损失函数。

类型 默认值 详情
loss_gen Callable 生成器的损失函数。评估生成器输出图像和目标真实图像
loss_crit Callable 判别器的损失函数。评估真实和假图像的预测。
weights_gen None | collections.abc.MutableSequence | tuple None 生成器和判别器损失函数的权重

GANLearner

 GANLearner (dls:DataLoaders, generator:nn.Module, critic:nn.Module,
             gen_loss_func:Callable, crit_loss_func:Callable,
             switcher:Callback|None=None, gen_first:bool=False,
             switch_eval:bool=True, show_img:bool=True,
             clip:None|float=None, cbs:Callback|None|MutableSequence=None,
             metrics:None|MutableSequence|Callable=None,
             loss_func:Callable|None=None,
             opt_func:Optimizer|OptimWrapper=<function Adam>,
             lr:float|slice=0.001, splitter:Callable=<function
             trainable_params>, path:str|Path|None=None,
             model_dir:str|Path='models', wd:float|int|None=None,
             wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95,
             0.85, 0.95), default_cbs:bool=True)

适用于 GAN 的 Learner

类型 默认值 详情
dls DataLoaders GAN 数据的 DataLoaders 对象
生成器 nn.Module 生成器模型
判别器 nn.Module 判别器模型
生成器损失函数 Callable 生成器损失函数
判别器损失函数 Callable 判别器损失函数
切换器 Callback | None None 用于在生成器和判别器训练之间切换的 Callback,默认为 FixedGANSwitcher
生成器优先 bool False 是否先开始生成器训练
切换评估模式 bool True 计算损失时模型是否应设置为评估模式
显示图像 bool True 训练期间是否显示生成的示例图像
裁剪 None | float None 权重的裁剪量
cbs Callback | None | MutableSequence None 额外的 Callbacks
指标 None | MutableSequence | Callable None 指标
损失函数 Callable | None None
opt_func Optimizer | OptimWrapper Adam
学习率 float | slice 0.001
splitter Callable 可训练参数
路径 str | Path | None None
模型目录 str | Path 模型
权重衰减 float | int | None None
wd_bn_bias bool False
训练 bn bool True
动量 tuple (0.95, 0.85, 0.95)
default_cbs bool True

GANLearner.from_learners

 GANLearner.from_learners (gen_learn:Learner, crit_learn:Learner,
                           switcher:Callback|None=None,
                           weights_gen:None|MutableSequence|tuple=None,
                           gen_first:bool=False, switch_eval:bool=True,
                           show_img:bool=True, clip:None|float=None,
                           cbs:Callback|None|MutableSequence=None,
                           metrics:None|MutableSequence|Callable=None,
                           loss_func:Callable|None=None,
                           opt_func:Optimizer|OptimWrapper=<function
                           Adam>, lr:float|slice=0.001,
                           splitter:Callable=<function trainable_params>,
                           path:str|Path|None=None,
                           model_dir:str|Path='models',
                           wd:float|int|None=None, wd_bn_bias:bool=False,
                           train_bn:bool=True, moms:tuple=(0.95, 0.85,
                           0.95), default_cbs:bool=True)

learn_genlearn_crit 创建 GAN。

类型 默认值 详情
生成器 Learner Learner 包含生成器的 Learner 对象
判别器 Learner Learner 包含判别器的 Learner 对象
切换器 Callback | None None 用于在生成器和判别器训练之间切换的 Callback,默认为 FixedGANSwitcher
weights_gen None | MutableSequence | tuple None 生成器和判别器损失函数的权重
生成器优先 bool False 是否先开始生成器训练
切换评估模式 bool True 计算损失时模型是否应设置为评估模式
显示图像 bool True 训练期间是否显示生成的示例图像
裁剪 None | float None 权重的裁剪量
cbs Callback | None | MutableSequence None 额外的 Callbacks
指标 None | MutableSequence | Callable None 指标
损失函数 可选 None 损失函数。默认为 dls 的损失函数
opt_func fastai.optimizer.Optimizer | fastai.optimizer.OptimWrapper Adam 训练的优化函数
学习率 float | slice 0.001 默认学习率
splitter Callable 可训练参数 将模型分割成参数组。默认为一个参数组
路径 str | pathlib.Path | None None 保存、加载和导出模型的父目录。默认为 dlspath
模型目录 str | pathlib.Path 模型 保存和加载模型的子目录
权重衰减 float | int | None None 默认权重衰减
wd_bn_bias bool False 对归一化和偏差参数应用权重衰减
训练 bn bool True 训练冻结的归一化层
动量 tuple (0.95, 0.85, 0.95) 调度器的默认动量
default_cbs bool True 包含默认的 Callbacks

GANLearner.wgan

 GANLearner.wgan (dls:DataLoaders, generator:nn.Module, critic:nn.Module,
                  switcher:Callback|None=None, clip:None|float=0.01,
                  switch_eval:bool=False, gen_first:bool=False,
                  show_img:bool=True,
                  cbs:Callback|None|MutableSequence=None,
                  metrics:None|MutableSequence|Callable=None,
                  loss_func:Callable|None=None,
                  opt_func:Optimizer|OptimWrapper=<function Adam>,
                  lr:float|slice=0.001, splitter:Callable=<function
                  trainable_params>, path:str|Path|None=None,
                  model_dir:str|Path='models', wd:float|int|None=None,
                  wd_bn_bias:bool=False, train_bn:bool=True,
                  moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)

dlsgeneratorcritic 创建 WGAN

类型 默认值 详情
dls DataLoaders GAN 数据的 DataLoaders 对象
生成器 nn.Module 生成器模型
判别器 nn.Module 判别器模型
切换器 Callback | None None 用于在生成器和判别器训练之间切换的 Callback,默认为 FixedGANSwitcher(n_crit=5, n_gen=1)
裁剪 None | float 0.01 权重的裁剪量
切换评估模式 bool False 计算损失时模型是否应设置为评估模式
生成器优先 bool False 是否先开始生成器训练
显示图像 bool True 训练期间是否显示生成的示例图像
cbs Callback | None | MutableSequence None 额外的 Callbacks
指标 None | MutableSequence | Callable None 指标
损失函数 可选 None 损失函数。默认为 dls 的损失函数
opt_func fastai.optimizer.Optimizer | fastai.optimizer.OptimWrapper Adam 训练的优化函数
学习率 float | slice 0.001 默认学习率
splitter Callable 可训练参数 将模型分割成参数组。默认为一个参数组
路径 str | pathlib.Path | None None 保存、加载和导出模型的父目录。默认为 dlspath
模型目录 str | pathlib.Path 模型 保存和加载模型的子目录
权重衰减 float | int | None None 默认权重衰减
wd_bn_bias bool False 对归一化和偏差参数应用权重衰减
训练 bn bool True 训练冻结的归一化层
动量 tuple (0.95, 0.85, 0.95) 调度器的默认动量
default_cbs bool True 包含默认的 Callbacks
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (generator) that exists in the learner. Use `self.learn.generator` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (critic) that exists in the learner. Use `self.learn.critic` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
周期 训练损失 生成器损失 判别器损失 时间
0 -0.815071 0.646809 -1.140522 00:38
/home/tmabraham/anaconda3/envs/fastai/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
  warn("Your generator is empty.")
learn.show_results(max_n=9, ds_idx=0)