= basic_critic(64, 3)
critic = basic_generator(64, 3)
generator = GANModule(critic=critic, generator=generator)
tst = torch.randn(2, 3, 64, 64)
real = tst(real)
real_p 2,1])
test_eq(real_p.shape, [
#tst is now in generator mode
tst.switch() = torch.randn(2, 100)
noise = tst(noise)
fake
test_eq(fake.shape, real.shape)
#tst is back in critic mode
tst.switch() = tst(fake)
fake_p 2,1]) test_eq(fake_p.shape, [
GAN
GAN 是指 生成对抗网络 (Generative Adversarial Nets),由 Ian Goodfellow 发明。其概念是同时训练两个模型:生成器和判别器。生成器试图创建与数据集中图像相似的新图像,而判别器则试图区分真实图像和生成器创建的图像。生成器返回图像,判别器返回一个单一数字(通常是概率,假图像为 0,真实图像为 1)。
我们让他们相互对抗,其训练过程大致如下:
- 冻结生成器,训练判别器一步:
- 获取一批真实图像(称之为
real
) - 生成一批假图像(称之为
fake
) - 让判别器评估每一批图像并计算损失函数;重要的一点是,它对检测到真实图像给予正面奖励,对检测到假图像进行惩罚
- 使用此损失的梯度更新判别器的权重
- 冻结判别器,训练生成器一步:
- 生成一批假图像
- 让判别器对其进行评估
- 返回一个损失,该损失对判别器认为这些是真实图像的情况给予正面奖励
- 使用此损失的梯度更新生成器的权重
fastai 库通过 GANTrainer 提供对 GAN 训练的支持,但不包含高级模型,只提供基本模型。
封装模块
GANModule
GANModule (generator:nn.Module=None, critic:nn.Module=None, gen_mode:None|bool=False)
封装 generator
和 critic
以创建 GAN。
类型 | 默认值 | 详情 | |
---|---|---|---|
生成器 | Module | None | 生成器 PyTorch 模块 |
判别器 | Module | None | 判别器 PyTorch 模块 |
生成模式 | None | bool | False | GAN 是否应设置为生成模式 |
这只是一个包含两个模型的壳。调用时,它将根据 gen_mode
的值将输入委托给 generator
或 critic
。
GANModule.switch
GANModule.switch (gen_mode:None|bool=None)
如果 gen_mode
为 True
,则将模块置于生成模式,否则置于判别模式。
类型 | 默认值 | 详情 | |
---|---|---|---|
生成模式 | None | bool | None | GAN 是否应设置为生成模式 |
默认情况下(将 gen_mode
设为 None
),这将把模块置于另一种模式(如果处于生成模式则切换到判别模式,反之亦然)。
basic_critic
basic_critic (in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
一个用于 n_channels
x in_size
x in_size
图像的基本判别器。
类型 | 默认值 | 详情 | |
---|---|---|---|
输入尺寸 | int | 判别器的输入尺寸(与生成器的输出尺寸相同) | |
通道数 | int | 判别器输入的通道数 | |
特征数 | int | 64 | 判别器中使用的特征数量 |
额外层数 | int | 0 | 判别器中额外隐藏层的数量 |
归一化类型 | NormType | NormType.Batch | 判别器中使用的归一化类型 |
核大小 | int | 3 | |
步长 | int | 1 | |
填充 | NoneType | None | |
偏差 | NoneType | None | |
维度 | int | 2 | |
bn_1st | bool | True | |
激活函数类型 | type | ReLU | |
转置 | bool | False | |
初始化 | str | auto | |
额外参数 | NoneType | None | |
偏差标准差 | float | 0.01 | |
膨胀 | Union | 1 | |
分组 | int | 1 | |
填充模式 | str | zeros | TODO: 完善此类型 |
设备 | NoneType | None | |
数据类型 | NoneType | None | |
返回值 | nn.Sequential |
AddChannels
AddChannels (n_dim)
在输入的末尾添加 n_dim
个通道。
basic_generator
basic_generator (out_size:int, n_channels:int, in_sz:int=100, n_features:int=64, n_extra_layers:int=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
一个从 in_sz
到 n_channels
x out_size
x out_size
图像的基本生成器。
类型 | 默认值 | 详情 | |
---|---|---|---|
输出尺寸 | int | 生成器的输出尺寸(与判别器的输入尺寸相同) | |
通道数 | int | 生成器输出的通道数 | |
输入尺寸 | int | 100 | 生成器的输入噪声向量大小 |
特征数 | int | 64 | 生成器中使用的特征数量 |
额外层数 | int | 0 | 生成器中额外隐藏层的数量 |
核大小 | int | 3 | |
步长 | int | 1 | |
填充 | NoneType | None | |
偏差 | NoneType | None | |
维度 | int | 2 | |
归一化类型 | NormType | NormType.Batch | |
bn_1st | bool | True | |
激活函数类型 | type | ReLU | |
转置 | bool | False | |
初始化 | str | auto | |
额外参数 | NoneType | None | |
偏差标准差 | float | 0.01 | |
膨胀 | Union | 1 | |
分组 | int | 1 | |
填充模式 | str | zeros | TODO: 完善此类型 |
设备 | NoneType | None | |
数据类型 | NoneType | None | |
返回值 | nn.Sequential |
DenseResBlock
DenseResBlock (nf:int, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
具有 nf
特征的 Resnet 块。conv_kwargs
传递给 conv_layer
。
类型 | 默认值 | 详情 | |
---|---|---|---|
特征数 | int | 特征的数量 | |
归一化类型 | NormType | NormType.Batch | 归一化类型 |
核大小 | int | 3 | |
步长 | int | 1 | |
填充 | NoneType | None | |
偏差 | NoneType | None | |
维度 | int | 2 | |
bn_1st | bool | True | |
激活函数类型 | type | ReLU | |
转置 | bool | False | |
初始化 | str | auto | |
额外参数 | NoneType | None | |
偏差标准差 | float | 0.01 | |
膨胀 | Union | 1 | |
分组 | int | 1 | |
填充模式 | str | zeros | TODO: 完善此类型 |
设备 | NoneType | None | |
数据类型 | NoneType | None | |
返回值 | SequentialEx |
gan_critic
gan_critic (n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15)
用于训练 GAN
的判别器。
类型 | 默认值 | 详情 | |
---|---|---|---|
通道数 | int | 3 | 判别器输入的通道数 |
特征数 | int | 128 | 判别器的特征数量 |
块数 | int | 3 | 判别器内的 ResNet 块数量 |
p | float | 0.15 | 判别器中的 dropout 量 |
返回值 | Sequential |
GANLoss
GANLoss (gen_loss_func:Callable, crit_loss_func:Callable, gan_model:GANModule)
封装 crit_loss_func
和 gen_loss_func
类型 | 详情 | |
---|---|---|
生成器损失函数 | Callable | 生成器损失函数 |
判别器损失函数 | Callable | 判别器损失函数 |
gan_model | GANModule | GAN 模型 |
GANLoss.generator
GANLoss.generator (output, target)
使用判别器评估 output
,然后使用 self.gen_loss_func
评估判别器被 output
欺骗的程度
详情 | |
---|---|
输出 | 生成器输出 |
目标 | 真实图像 |
GANLoss.critic
GANLoss.critic (real_pred, input)
使用生成器从 input
生成一些 fake_pred
,并在 self.crit_loss_func
中将它们与 real_pred
进行比较。
详情 | |
---|---|
真实预测 | 判别器对真实图像的预测 |
输入 | 传递给生成器的输入噪声向量 |
如果调用了 generator
方法,此损失函数期望生成器的 output
和一些 target
(一批真实图像)。它将使用 gen_loss_func
评估生成器是否成功欺骗了判别器。此损失函数具有以下签名:
def gen_loss_func(fake_pred, output, target):
以便能够将判别器对 output
的输出(即第一个参数 fake_pred
)与 output
和 target
相结合(如果您想将 GAN 损失与其他损失混合)。
如果调用了 critic
方法,此损失函数期望判别器给出的 real_pred
和一些 input
(输入给生成器的噪声)。它将使用 crit_loss_func
评估判别器。此损失函数具有以下签名:
def crit_loss_func(real_pred, fake_pred):
其中 real_pred
是判别器在一批真实图像上的输出,而 fake_pred
是使用生成器从噪声生成的。
AdaptiveLoss
AdaptiveLoss (crit:Callable)
在应用 crit
之前扩展 target
以匹配 output
的大小。
accuracy_thresh_expand
accuracy_thresh_expand (y_pred:torch.Tensor, y_true:torch.Tensor, thresh:float=0.5, sigmoid:bool=True)
在将 y_true
扩展到 y_pred
的大小后计算阈值精度。
GAN 训练 Callbacks
set_freeze_model
set_freeze_model (m:torch.nn.modules.module.Module, rg:bool)
类型 | 详情 | |
---|---|---|
模型 | Module | 要冻结/解冻的模型 |
rg | bool | Requires grad 参数。True 表示冻结 |
GANTrainer
GANTrainer (switch_eval:bool=False, clip:None|float=None, beta:float=0.98, gen_first:bool=False, show_img:bool=True)
处理 GAN 训练的 Callback。
类型 | 默认值 | 详情 | |
---|---|---|---|
切换评估模式 | bool | False | 计算损失时模型是否应设置为评估模式 |
裁剪 | None | float | None | 权重的裁剪量 |
beta | float | 0.98 | 损失的指数加权平滑 beta |
生成器优先 | bool | False | 是否先开始生成器训练 |
显示图像 | bool | True | 训练期间是否显示生成的示例图像 |
GANTrainer 本身是无用的,需要与以下其中一个 switcher 配合使用
FixedGANSwitcher
FixedGANSwitcher (n_crit:int=1, n_gen:int=1)
Switcher,先进行 n_crit
次判别器迭代,然后进行 n_gen
次生成器迭代。
类型 | 默认值 | 详情 | |
---|---|---|---|
判别器步数 | int | 1 | 切换到生成器训练前判别器训练的步数 |
生成器步数 | int | 1 | 切换到判别器训练前生成器训练的步数 |
AdaptiveGANSwitcher
AdaptiveGANSwitcher (gen_thresh:None|float=None, critic_thresh:None|float=None)
Switcher,当损失低于 gen_thresh
/crit_thresh
时,切换回生成器/判别器。
类型 | 默认值 | 详情 | |
---|---|---|---|
生成器阈值 | None | float | None | 生成器的损失阈值 |
判别器阈值 | None | float | None | 判别器的损失阈值 |
GANDiscriminativeLR
GANDiscriminativeLR (mult_lr=5.0)
Callback
,处理判别器学习率乘以 mult_lr
。
GAN 数据
InvisibleTensor
InvisibleTensor (x, **kwargs)
TensorBase,但 show 方法不执行任何操作
generate_noise
generate_noise (fn, size=100)
生成噪声向量。
类型 | 默认值 | 详情 | |
---|---|---|---|
fn | 虚拟参数,使其与 DataBlock 配合使用 |
||
尺寸 | int | 100 | 返回的噪声向量大小 |
返回值 | InvisibleTensor |
我们使用 generate_noise
函数生成噪声向量,将其传递给生成器进行图像生成。
= 128
bs = 64 size
= DataBlock(blocks = (TransformBlock, ImageBlock),
dblock = generate_noise,
get_x = get_image_files,
get_items = IndexSplitter([]),
splitter =Resize(size, method=ResizeMethod.Crop),
item_tfms= Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5]))) batch_tfms
= untar_data(URLs.LSUN_BEDROOMS) path
= dblock.dataloaders(path, path=path, bs=bs) dls
=16) dls.show_batch(max_n
GAN Learner
gan_loss_from_func
gan_loss_from_func (loss_gen:Callable, loss_crit:Callable, weights_gen:None|collections.abc.MutableSequence|tupl e=None)
从 loss_gen
和 loss_crit
定义 GAN 的损失函数。
类型 | 默认值 | 详情 | |
---|---|---|---|
loss_gen | Callable | 生成器的损失函数。评估生成器输出图像和目标真实图像 | |
loss_crit | Callable | 判别器的损失函数。评估真实和假图像的预测。 | |
weights_gen | None | collections.abc.MutableSequence | tuple | None | 生成器和判别器损失函数的权重 |
GANLearner
GANLearner (dls:DataLoaders, generator:nn.Module, critic:nn.Module, gen_loss_func:Callable, crit_loss_func:Callable, switcher:Callback|None=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:None|float=None, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|Callable=None, loss_func:Callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:Callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
适用于 GAN 的 Learner
。
类型 | 默认值 | 详情 | |
---|---|---|---|
dls | DataLoaders | GAN 数据的 DataLoaders 对象 | |
生成器 | nn.Module | 生成器模型 | |
判别器 | nn.Module | 判别器模型 | |
生成器损失函数 | Callable | 生成器损失函数 | |
判别器损失函数 | Callable | 判别器损失函数 | |
切换器 | Callback | None | None | 用于在生成器和判别器训练之间切换的 Callback,默认为 FixedGANSwitcher |
生成器优先 | bool | False | 是否先开始生成器训练 |
切换评估模式 | bool | True | 计算损失时模型是否应设置为评估模式 |
显示图像 | bool | True | 训练期间是否显示生成的示例图像 |
裁剪 | None | float | None | 权重的裁剪量 |
cbs | Callback | None | MutableSequence | None | 额外的 Callbacks |
指标 | None | MutableSequence | Callable | None | 指标 |
损失函数 | Callable | None | None | |
opt_func | Optimizer | OptimWrapper | Adam | |
学习率 | float | slice | 0.001 | |
splitter | Callable | 可训练参数 | |
路径 | str | Path | None | None | |
模型目录 | str | Path | 模型 | |
权重衰减 | float | int | None | None | |
wd_bn_bias | bool | False | |
训练 bn | bool | True | |
动量 | tuple | (0.95, 0.85, 0.95) | |
default_cbs | bool | True |
GANLearner.from_learners
GANLearner.from_learners (gen_learn:Learner, crit_learn:Learner, switcher:Callback|None=None, weights_gen:None|MutableSequence|tuple=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:None|float=None, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|Callable=None, loss_func:Callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:Callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
从 learn_gen
和 learn_crit
创建 GAN。
类型 | 默认值 | 详情 | |
---|---|---|---|
生成器 Learner | Learner | 包含生成器的 Learner 对象 |
|
判别器 Learner | Learner | 包含判别器的 Learner 对象 |
|
切换器 | Callback | None | None | 用于在生成器和判别器训练之间切换的 Callback,默认为 FixedGANSwitcher |
weights_gen | None | MutableSequence | tuple | None | 生成器和判别器损失函数的权重 |
生成器优先 | bool | False | 是否先开始生成器训练 |
切换评估模式 | bool | True | 计算损失时模型是否应设置为评估模式 |
显示图像 | bool | True | 训练期间是否显示生成的示例图像 |
裁剪 | None | float | None | 权重的裁剪量 |
cbs | Callback | None | MutableSequence | None | 额外的 Callbacks |
指标 | None | MutableSequence | Callable | None | 指标 |
损失函数 | 可选 | None | 损失函数。默认为 dls 的损失函数 |
opt_func | fastai.optimizer.Optimizer | fastai.optimizer.OptimWrapper | Adam | 训练的优化函数 |
学习率 | float | slice | 0.001 | 默认学习率 |
splitter | Callable | 可训练参数 | 将模型分割成参数组。默认为一个参数组 |
路径 | str | pathlib.Path | None | None | 保存、加载和导出模型的父目录。默认为 dls 的 path |
模型目录 | str | pathlib.Path | 模型 | 保存和加载模型的子目录 |
权重衰减 | float | int | None | None | 默认权重衰减 |
wd_bn_bias | bool | False | 对归一化和偏差参数应用权重衰减 |
训练 bn | bool | True | 训练冻结的归一化层 |
动量 | tuple | (0.95, 0.85, 0.95) | 调度器的默认动量 |
default_cbs | bool | True | 包含默认的 Callback s |
GANLearner.wgan
GANLearner.wgan (dls:DataLoaders, generator:nn.Module, critic:nn.Module, switcher:Callback|None=None, clip:None|float=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|Callable=None, loss_func:Callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:Callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
从 dls
、generator
和 critic
创建 WGAN。
类型 | 默认值 | 详情 | |
---|---|---|---|
dls | DataLoaders | GAN 数据的 DataLoaders 对象 | |
生成器 | nn.Module | 生成器模型 | |
判别器 | nn.Module | 判别器模型 | |
切换器 | Callback | None | None | 用于在生成器和判别器训练之间切换的 Callback,默认为 FixedGANSwitcher(n_crit=5, n_gen=1) |
裁剪 | None | float | 0.01 | 权重的裁剪量 |
切换评估模式 | bool | False | 计算损失时模型是否应设置为评估模式 |
生成器优先 | bool | False | 是否先开始生成器训练 |
显示图像 | bool | True | 训练期间是否显示生成的示例图像 |
cbs | Callback | None | MutableSequence | None | 额外的 Callbacks |
指标 | None | MutableSequence | Callable | None | 指标 |
损失函数 | 可选 | None | 损失函数。默认为 dls 的损失函数 |
opt_func | fastai.optimizer.Optimizer | fastai.optimizer.OptimWrapper | Adam | 训练的优化函数 |
学习率 | float | slice | 0.001 | 默认学习率 |
splitter | Callable | 可训练参数 | 将模型分割成参数组。默认为一个参数组 |
路径 | str | pathlib.Path | None | None | 保存、加载和导出模型的父目录。默认为 dls 的 path |
模型目录 | str | pathlib.Path | 模型 | 保存和加载模型的子目录 |
权重衰减 | float | int | None | None | 默认权重衰减 |
wd_bn_bias | bool | False | 对归一化和偏差参数应用权重衰减 |
训练 bn | bool | True | 训练冻结的归一化层 |
动量 | tuple | (0.95, 0.85, 0.95) | 调度器的默认动量 |
default_cbs | bool | True | 包含默认的 Callback s |
from fastai.callback.all import *
= basic_generator(64, n_channels=3, n_extra_layers=1)
generator = basic_critic (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2)) critic
= GANLearner.wgan(dls, generator, critic, opt_func = RMSProp) learn
=True
learn.recorder.train_metrics=False learn.recorder.valid_metrics
1, 2e-4, wd=0.) learn.fit(
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (generator) that exists in the learner. Use `self.learn.generator` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (critic) that exists in the learner. Use `self.learn.critic` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
/home/tmabraham/git/fastai/fastai/callback/core.py:52: UserWarning: You are shadowing an attribute (gen_mode) that exists in the learner. Use `self.learn.gen_mode` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
周期 | 训练损失 | 生成器损失 | 判别器损失 | 时间 |
---|---|---|---|---|
0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
/home/tmabraham/anaconda3/envs/fastai/lib/python3.7/site-packages/fastprogress/fastprogress.py:74: UserWarning: Your generator is empty.
warn("Your generator is empty.")
=9, ds_idx=0) learn.show_results(max_n