'before_step') test_eq(event.before_step,
回调
事件
回调可以在以下任何时间发生: after_create before_fit before_epoch before_train before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit。
event
event (*args, **kwargs)
所有可能的事件作为属性,用于标签补全和防错别字
为确保您引用的是一个存在的事件(即回调被调用的某个时间点的名称),并获得事件名称的标签补全,请使用 event
Callback
Callback (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
通过在各种事件中更改 Learner
来处理训练循环调整的基本类
训练循环在下方的 Learner
中定义,包含一组最小指令:遍历数据,我们
- 计算模型根据输入的输出
- 计算此输出与期望目标之间的损失
- 计算此损失相对于所有模型参数的梯度
- 相应地更新参数
- 将所有梯度归零
对此训练循环的任何调整都在 Callback
中定义,以避免训练循环的代码过于复杂,并使其易于混合和匹配不同的技术(因为它们将在不同的回调中定义)。回调可以在以下事件上实现操作
after_create
:在Learner
创建后调用before_fit
:在开始训练或推理前调用,非常适合初始设置。before_epoch
:在每个 epoch 开始时调用,用于在每个 epoch 需要重置的任何行为。before_train
:在 epoch 的训练阶段开始时调用。before_batch
:在每个批次开始时调用,就在获取该批次之后。可用于进行该批次的任何必要设置(如超参数调度),或在输入/目标进入模型之前更改它们(例如使用 mixup 等技术更改输入)。after_pred
:在计算模型在该批次上的输出后调用。可用于在输出馈送到损失函数之前更改它。after_loss
:在计算损失后调用,但在反向传播之前。可用于向损失添加任何惩罚项(例如 RNN 训练中的 AR 或 TAR)。before_backward
:在计算损失后调用,但仅在训练模式下(即使用反向传播时)after_backward
:在反向传播后调用,但在参数更新之前。通常应该改用before_step
。before_step
:在反向传播后调用,但在参数更新之前。可用于在所述更新之前对梯度进行任何更改(例如梯度裁剪)。after_step
:在步骤执行后且梯度归零之前调用。after_batch
:在一个批次结束时调用,用于在下一个批次之前进行任何清理工作。after_train
:在 epoch 的训练阶段结束时调用。before_validate
:在 epoch 的验证阶段开始时调用,用于专门为验证所需的任何设置。after_validate
:在 epoch 的验证阶段结束时调用。after_epoch
:在一个 epoch 结束时调用,用于在下一个 epoch 之前进行任何清理工作。after_fit
:在训练结束时调用,用于最终清理工作。
Callback.__call__
Callback.__call__ (event_name)
如果定义了 self.{event_name}
,则调用它
定义回调的一种方法是通过子类化
class _T(Callback):
def call_me(self): return "maybe"
"call_me"), "maybe") test_eq(_T()(
另一种方法是将回调函数传递给构造函数
def cb(self): return "maybe"
= Callback(before_fit=cb)
_t "maybe") test_eq(_t(event.before_fit),
Callback
提供了快捷方式,避免了对于任何我们查找的 bla
属性,都必须写 self.learn.bla
;相反,只需写 self.bla
。这仅适用于获取属性,不适用于设置属性。
'TstLearner', 'a')
mk_class(
class TstCallback(Callback):
def batch_begin(self): print(self.a)
= TstLearner(1),TstCallback()
learn,cb = learn
cb.learn lambda: cb('batch_begin'), "1") test_stdout(
如果要更改属性的值,必须使用 self.learn.bla
,而不是 self.bla
。在下面的示例中,self.a += 1
在回调中创建了一个值为 2 的 a
属性,而不是将 learner 的 a
设置为 2。它还会发出警告,提示可能出了问题
learn.a
1
class TstCallback(Callback):
def batch_begin(self): self.a += 1
= TstLearner(1),TstCallback()
learn,cb = learn
cb.learn 'batch_begin')
cb(2)
test_eq(cb.a, 1) test_eq(cb.learn.a,
/tmp/ipykernel_5201/1369389649.py:29: UserWarning: You are shadowing an attribute (a) that exists in the learner. Use `self.learn.a` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
正确版本需要写 self.learn.a = self.a + 1
class TstCallback(Callback):
def batch_begin(self): self.learn.a = self.a + 1
= TstLearner(1),TstCallback()
learn,cb = learn
cb.learn 'batch_begin')
cb(2) test_eq(cb.learn.a,
Callback.name
Callback.name ()
Callback
的名称,使用驼峰命名法,并去掉 ‘Callback’
'tst')
test_eq(TstCallback().name, class ComplicatedNameCallback(Callback): pass
'complicated_name') test_eq(ComplicatedNameCallback().name,
TrainEvalCallback
TrainEvalCallback (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
跟踪已完成的迭代次数并正确设置训练/评估模式的 Callback
回调可用的属性
编写回调时,可以使用 Learner
的以下属性
model
:用于训练/验证的模型dls
:底层DataLoaders
loss_func
:使用的损失函数opt
:用于更新模型参数的优化器opt_func
:用于创建优化器的函数cbs
:包含所有Callback
的列表dl
:当前用于迭代的DataLoader
x
/xb
:从self.dl
中提取的最后一个输入(可能被回调修改)。xb
始终是一个元组(可能包含一个元素),而x
是解元组后的结果。您只能赋值给xb
。y
/yb
:从self.dl
中提取的最后一个目标(可能被回调修改)。yb
始终是一个元组(可能包含一个元素),而y
是解元组后的结果。您只能赋值给yb
。pred
:从self.model
获得的最后一个预测结果(可能被回调修改)loss_grad
:最后一个计算出的损失(可能被回调修改)loss
:loss_grad
的克隆,用于日志记录n_epoch
:本次训练的总 epoch 数n_iter
:当前self.dl
中的迭代次数epoch
:当前 epoch 索引(从 0 到n_epoch-1
)iter
:当前self.dl
中的迭代索引(从 0 到n_iter-1
)
以下属性由 TrainEvalCallback
添加,除非您特意将其移除,否则应该可用
train_iter
:自本次训练开始以来完成的训练迭代次数pct_train
:从 0. 到 1.,已完成训练迭代的百分比training
:指示当前是否处于训练模式的标志
以下属性由 Recorder
添加,除非您特意将其移除,否则应该可用
smooth_loss
:训练损失的指数平均版本
回调控制流程
有时我们可能想跳过训练循环中的某些步骤:例如,在梯度累积中,我们不总是想执行步进/梯度归零操作。在 LR finder 测试期间,我们不想执行 epoch 的验证阶段。或者,如果我们使用早停策略进行训练,我们希望能够完全中断训练循环。
通过引发训练循环将寻找(并正确捕获)的特定异常,这成为可能。
CancelStepException
CancelStepException (*args, **kwargs)
跳过优化器的步进
CancelBatchException
CancelBatchException (*args, **kwargs)
跳过当前批次的其余部分,直接进入 after_batch
CancelBackwardException
CancelBackwardException (*args, **kwargs)
跳过反向传播,直接进入 after_backward
CancelTrainException
CancelTrainException (*args, **kwargs)
跳过 epoch 训练部分的其余部分,直接进入 after_train
CancelValidException
CancelValidException (*args, **kwargs)
跳过 epoch 验证部分的其余部分,直接进入 after_validate
CancelEpochException
CancelEpochException (*args, **kwargs)
跳过当前 epoch 的其余部分,直接进入 after_epoch
CancelFitException
CancelFitException (*args, **kwargs)
中断训练,直接进入 after_fit
您可以通过以下事件检测到其中一个异常的发生,并在其后立即添加执行的代码
after_cancel_batch
:在发生CancelBatchException
后立即到达,之后进入after_batch
after_cancel_train
:在发生CancelTrainException
后立即到达,之后进入after_epoch
after_cancel_valid
:在发生CancelValidException
后立即到达,之后进入after_epoch
after_cancel_epoch
:在发生CancelEpochException
后立即到达,之后进入after_epoch
after_cancel_fit
:在发生CancelFitException
后立即到达,之后进入after_fit
GatherPredsCallback
GatherPredsCallback (with_input:bool=False, with_loss:bool=False, save_preds:pathlib.Path=None, save_targs:pathlib.Path=None, with_preds:bool=True, with_targs:bool=True, concat_dim:int=0, pickle_protocol:int=2)
返回所有预测和目标,可选包含 with_input
或 with_loss
的 Callback
类型 | 默认值 | 详情 | |
---|---|---|---|
with_input | bool | False | 是否返回输入 |
with_loss | bool | False | 是否返回损失 |
save_preds | Path | None | 保存预测结果的路径 |
save_targs | Path | None | 保存目标的路径 |
with_preds | bool | True | 是否返回预测结果 |
with_targs | bool | True | 是否返回目标 |
concat_dim | int | 0 | 拼接返回张量的维度 |
pickle_protocol | int | 2 | 用于保存预测结果和目标的 Pickle 协议 |
FetchPredsCallback
FetchPredsCallback (ds_idx:int=1, dl:fastai.data.load.DataLoader=None, with_input:bool=False, with_decoded:bool=False, cbs:_ _main__.Callback|collections.abc.MutableSequence=None , reorder:bool=True)
在训练循环期间获取预测结果的回调
类型 | 默认值 | 详情 | |
---|---|---|---|
ds_idx | int | 1 | 数据集索引,0 表示训练集,1 表示验证集,当 dl 不存在时使用 |
dl | DataLoader | None | 用于获取 Learner 预测结果的 DataLoader |
with_input | bool | False | 是否在 GatherPredsCallback 中返回输入 |
with_decoded | bool | False | 是否返回解码后的预测结果 |
cbs | main.Callback | collections.abc.MutableSequence | None | 暂时从 Learner 中移除的 Callback |
reorder | bool | True | 是否对预测结果进行排序 |