Channels Last 训练
使用 MixedPrecision
,在 Tensor Cores 上以 Channels Last 格式训练的图像模型可以比 Contiguous 格式提高训练吞吐量。PyTorch 观察到,使用 Channels Last 格式训练 ResNet50 的速度提高了 22%,并在 V100 上测试的一系列模型中获得了 8-35% 的提升。
Channels Last 格式兼容现代 GPU (Volta, Turing 或更新型号) 和现代 CPU (Ice Lake 或更新型号)。
Channels Last 内存格式目前已针对 NCHW 张量实现。并非所有 PyTorch 运算符都已转换为支持 Channels Last。有关更多详情,请参阅 PyTorch 中的 Channels Last 内存格式(Beta 版)教程。
ChannelsLast
ChannelsLast (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
使用 PyTorch 的 Channels Last 内存格式进行 Channels Last 训练(Beta 版)
当 PyTorch 模型设置为 Channels Last 格式时,PyTorch 会自动将任何兼容的 NCHW 输入张量转换为 NHWC 格式。ChannelsLast
将模型设置为 Channels Last 格式,因此无需更改 dataloader 或输入。
ChannelsLast
应适用于大多数卷积型的 timm
模型。
但是,建议对每个模型进行测试,因为不同 PyTorch 版本支持的操作有所不同。
在不受支持的 PyTorch 操作中使用 ChannelsLast 可能导致“通道抖动”(channel thrashing),即在不受支持的 PyTorch 操作中,Channels Last 输入被转换为 Contiguous 格式,然后转回 Channels Last 以在 Tensor Core 上执行,返回操作时再转回 Contiguous,最后再转换为 Channels Last 以供下一层使用。模型中过多的不受支持的操作会导致性能下降。
Learner.to_channelslast
Learner.to_channelslast (use_amp:bool=True, amp_mode:str|AMPMode=<AMPMode.FP16: 'fp16'>, init_scale:float=65536.0, growth_factor:float=2.0, backoff_factor:float=0.5, growth_interval:int=2000, enabled:bool=True)
默认将 Learner 和输入设置为 channels_last
格式和 float16 混合精度
类型 | 默认值 | 详情 | |
---|---|---|---|
use_amp | bool | True | 添加带 amp_mode 的 MixedPrecision 。建议用于充分发挥 Channels Last 性能 |
amp_mode | str | AMPMode | AMPMode.FP16 | 混合精度训练模式。支持 fp16 和 bf16。 |
init_scale | float | 65536.0 | |
growth_factor | float | 2.0 | |
backoff_factor | float | 0.5 | |
growth_interval | int | 2000 | |
enabled | bool | True |
Learner.to_contiguous
Learner.to_contiguous (to_fp32:bool=False)
将 Learner 和输入设置为 contiguous_format
(默认格式),可选地设置为单精度