回调

Learner 的基本回调

事件

回调可以在以下任何时间发生: after_create before_fit before_epoch before_train before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit


event

 event (*args, **kwargs)

所有可能的事件作为属性,用于标签补全和防错别字

为确保您引用的是一个存在的事件(即回调被调用的某个时间点的名称),并获得事件名称的标签补全,请使用 event

test_eq(event.before_step, 'before_step')

源码

Callback

 Callback (after_create=None, before_fit=None, before_epoch=None,
           before_train=None, before_batch=None, after_pred=None,
           after_loss=None, before_backward=None,
           after_cancel_backward=None, after_backward=None,
           before_step=None, after_cancel_step=None, after_step=None,
           after_cancel_batch=None, after_batch=None,
           after_cancel_train=None, after_train=None,
           before_validate=None, after_cancel_validate=None,
           after_validate=None, after_cancel_epoch=None, after_epoch=None,
           after_cancel_fit=None, after_fit=None)

通过在各种事件中更改 Learner 来处理训练循环调整的基本类

训练循环在下方的 Learner 中定义,包含一组最小指令:遍历数据,我们

  • 计算模型根据输入的输出
  • 计算此输出与期望目标之间的损失
  • 计算此损失相对于所有模型参数的梯度
  • 相应地更新参数
  • 将所有梯度归零

对此训练循环的任何调整都在 Callback 中定义,以避免训练循环的代码过于复杂,并使其易于混合和匹配不同的技术(因为它们将在不同的回调中定义)。回调可以在以下事件上实现操作

  • after_create:在 Learner 创建后调用
  • before_fit:在开始训练或推理前调用,非常适合初始设置。
  • before_epoch:在每个 epoch 开始时调用,用于在每个 epoch 需要重置的任何行为。
  • before_train:在 epoch 的训练阶段开始时调用。
  • before_batch:在每个批次开始时调用,就在获取该批次之后。可用于进行该批次的任何必要设置(如超参数调度),或在输入/目标进入模型之前更改它们(例如使用 mixup 等技术更改输入)。
  • after_pred:在计算模型在该批次上的输出后调用。可用于在输出馈送到损失函数之前更改它。
  • after_loss:在计算损失后调用,但在反向传播之前。可用于向损失添加任何惩罚项(例如 RNN 训练中的 AR 或 TAR)。
  • before_backward:在计算损失后调用,但仅在训练模式下(即使用反向传播时)
  • after_backward:在反向传播后调用,但在参数更新之前。通常应该改用 before_step
  • before_step:在反向传播后调用,但在参数更新之前。可用于在所述更新之前对梯度进行任何更改(例如梯度裁剪)。
  • after_step:在步骤执行后且梯度归零之前调用。
  • after_batch:在一个批次结束时调用,用于在下一个批次之前进行任何清理工作。
  • after_train:在 epoch 的训练阶段结束时调用。
  • before_validate:在 epoch 的验证阶段开始时调用,用于专门为验证所需的任何设置。
  • after_validate:在 epoch 的验证阶段结束时调用。
  • after_epoch:在一个 epoch 结束时调用,用于在下一个 epoch 之前进行任何清理工作。
  • after_fit:在训练结束时调用,用于最终清理工作。

源码

Callback.__call__

 Callback.__call__ (event_name)

如果定义了 self.{event_name},则调用它

定义回调的一种方法是通过子类化

class _T(Callback):
    def call_me(self): return "maybe"
test_eq(_T()("call_me"), "maybe")

另一种方法是将回调函数传递给构造函数

def cb(self): return "maybe"
_t = Callback(before_fit=cb)
test_eq(_t(event.before_fit), "maybe")

Callback 提供了快捷方式,避免了对于任何我们查找的 bla 属性,都必须写 self.learn.bla;相反,只需写 self.bla。这仅适用于获取属性,适用于设置属性。

mk_class('TstLearner', 'a')

class TstCallback(Callback):
    def batch_begin(self): print(self.a)

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
test_stdout(lambda: cb('batch_begin'), "1")

如果要更改属性的值,必须使用 self.learn.bla,而不是 self.bla。在下面的示例中,self.a += 1 在回调中创建了一个值为 2 的 a 属性,而不是将 learner 的 a 设置为 2。它还会发出警告,提示可能出了问题

learn.a
1
class TstCallback(Callback):
    def batch_begin(self): self.a += 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.a, 2)
test_eq(cb.learn.a, 1)
/tmp/ipykernel_5201/1369389649.py:29: UserWarning: You are shadowing an attribute (a) that exists in the learner. Use `self.learn.a` to avoid this
  warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")

正确版本需要写 self.learn.a = self.a + 1

class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.learn.a, 2)

源码

Callback.name

 Callback.name ()

Callback 的名称,使用驼峰命名法,并去掉 ‘Callback

test_eq(TstCallback().name, 'tst')
class ComplicatedNameCallback(Callback): pass
test_eq(ComplicatedNameCallback().name, 'complicated_name')

源码

TrainEvalCallback

 TrainEvalCallback (after_create=None, before_fit=None, before_epoch=None,
                    before_train=None, before_batch=None, after_pred=None,
                    after_loss=None, before_backward=None,
                    after_cancel_backward=None, after_backward=None,
                    before_step=None, after_cancel_step=None,
                    after_step=None, after_cancel_batch=None,
                    after_batch=None, after_cancel_train=None,
                    after_train=None, before_validate=None,
                    after_cancel_validate=None, after_validate=None,
                    after_cancel_epoch=None, after_epoch=None,
                    after_cancel_fit=None, after_fit=None)

跟踪已完成的迭代次数并正确设置训练/评估模式的 Callback

Callback 在每个 Learner 初始化时自动添加。

回调可用的属性

编写回调时,可以使用 Learner 的以下属性

  • model:用于训练/验证的模型
  • dls:底层 DataLoaders
  • loss_func:使用的损失函数
  • opt:用于更新模型参数的优化器
  • opt_func:用于创建优化器的函数
  • cbs:包含所有 Callback 的列表
  • dl:当前用于迭代的 DataLoader
  • x/xb:从 self.dl 中提取的最后一个输入(可能被回调修改)。xb 始终是一个元组(可能包含一个元素),而 x 是解元组后的结果。您只能赋值给 xb
  • y/yb:从 self.dl 中提取的最后一个目标(可能被回调修改)。yb 始终是一个元组(可能包含一个元素),而 y 是解元组后的结果。您只能赋值给 yb
  • pred:从 self.model 获得的最后一个预测结果(可能被回调修改)
  • loss_grad:最后一个计算出的损失(可能被回调修改)
  • lossloss_grad 的克隆,用于日志记录
  • n_epoch:本次训练的总 epoch 数
  • n_iter:当前 self.dl 中的迭代次数
  • epoch:当前 epoch 索引(从 0 到 n_epoch-1
  • iter:当前 self.dl 中的迭代索引(从 0 到 n_iter-1

以下属性由 TrainEvalCallback 添加,除非您特意将其移除,否则应该可用

  • train_iter:自本次训练开始以来完成的训练迭代次数
  • pct_train:从 0. 到 1.,已完成训练迭代的百分比
  • training:指示当前是否处于训练模式的标志

以下属性由 Recorder 添加,除非您特意将其移除,否则应该可用

  • smooth_loss:训练损失的指数平均版本

回调控制流程

有时我们可能想跳过训练循环中的某些步骤:例如,在梯度累积中,我们不总是想执行步进/梯度归零操作。在 LR finder 测试期间,我们不想执行 epoch 的验证阶段。或者,如果我们使用早停策略进行训练,我们希望能够完全中断训练循环。

通过引发训练循环将寻找(并正确捕获)的特定异常,这成为可能。


CancelStepException

 CancelStepException (*args, **kwargs)

跳过优化器的步进


CancelBatchException

 CancelBatchException (*args, **kwargs)

跳过当前批次的其余部分,直接进入 after_batch


CancelBackwardException

 CancelBackwardException (*args, **kwargs)

跳过反向传播,直接进入 after_backward


CancelTrainException

 CancelTrainException (*args, **kwargs)

跳过 epoch 训练部分的其余部分,直接进入 after_train


CancelValidException

 CancelValidException (*args, **kwargs)

跳过 epoch 验证部分的其余部分,直接进入 after_validate


CancelEpochException

 CancelEpochException (*args, **kwargs)

跳过当前 epoch 的其余部分,直接进入 after_epoch


CancelFitException

 CancelFitException (*args, **kwargs)

中断训练,直接进入 after_fit

您可以通过以下事件检测到其中一个异常的发生,并在其后立即添加执行的代码

  • after_cancel_batch:在发生 CancelBatchException 后立即到达,之后进入 after_batch
  • after_cancel_train:在发生 CancelTrainException 后立即到达,之后进入 after_epoch
  • after_cancel_valid:在发生 CancelValidException 后立即到达,之后进入 after_epoch
  • after_cancel_epoch:在发生 CancelEpochException 后立即到达,之后进入 after_epoch
  • after_cancel_fit:在发生 CancelFitException 后立即到达,之后进入 after_fit

源码

GatherPredsCallback

 GatherPredsCallback (with_input:bool=False, with_loss:bool=False,
                      save_preds:pathlib.Path=None,
                      save_targs:pathlib.Path=None, with_preds:bool=True,
                      with_targs:bool=True, concat_dim:int=0,
                      pickle_protocol:int=2)

返回所有预测和目标,可选包含 with_inputwith_lossCallback

类型 默认值 详情
with_input bool False 是否返回输入
with_loss bool False 是否返回损失
save_preds Path None 保存预测结果的路径
save_targs Path None 保存目标的路径
with_preds bool True 是否返回预测结果
with_targs bool True 是否返回目标
concat_dim int 0 拼接返回张量的维度
pickle_protocol int 2 用于保存预测结果和目标的 Pickle 协议

源码

FetchPredsCallback

 FetchPredsCallback (ds_idx:int=1, dl:fastai.data.load.DataLoader=None,
                     with_input:bool=False, with_decoded:bool=False, cbs:_
                     _main__.Callback|collections.abc.MutableSequence=None
                     , reorder:bool=True)

在训练循环期间获取预测结果的回调

类型 默认值 详情
ds_idx int 1 数据集索引,0 表示训练集,1 表示验证集,当 dl 不存在时使用
dl DataLoader None 用于获取 Learner 预测结果的 DataLoader
with_input bool False 是否在 GatherPredsCallback 中返回输入
with_decoded bool False 是否返回解码后的预测结果
cbs main.Callback | collections.abc.MutableSequence None 暂时从 Learner 中移除的 Callback
reorder bool True 是否对预测结果进行排序