from fastai.test_utils import *
进度与日志记录
用于跟踪训练进度或记录结果的回调函数和辅助函数
ProgressCallback
ProgressCallback (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
一个 Callback
用于处理进度条显示
= synth_learner()
learn 5) learn.fit(
周期 | 训练损失 | 验证损失 | 时间 |
---|---|---|---|
0 | 14.523648 | 10.988108 | 00:00 |
1 | 12.395808 | 7.306935 | 00:00 |
2 | 10.121231 | 4.370981 | 00:00 |
3 | 8.065226 | 2.487984 | 00:00 |
4 | 6.374166 | 1.368232 | 00:00 |
no_bar
no_bar ()
禁用进度条使用的上下文管理器
= synth_learner()
learn with learn.no_bar(): learn.fit(5)
[0, 15.748106002807617, 12.352150917053223, '00:00']
[1, 13.818815231323242, 8.879858016967773, '00:00']
[2, 11.650713920593262, 5.857329845428467, '00:00']
[3, 9.595088005065918, 3.7397098541259766, '00:00']
[4, 7.814438343048096, 2.327916145324707, '00:00']
ProgressCallback.before_fit
ProgressCallback.before_fit ()
在训练周期上设置主进度条
ProgressCallback.before_epoch
ProgressCallback.before_epoch ()
更新主进度条
ProgressCallback.before_train
ProgressCallback.before_train ()
启动训练数据加载器上的进度条
ProgressCallback.before_validate
ProgressCallback.before_validate ()
启动验证数据加载器上的进度条
ProgressCallback.after_batch
ProgressCallback.after_batch ()
更新当前进度条
ProgressCallback.after_train
ProgressCallback.after_train ()
关闭训练数据加载器上的进度条
ProgressCallback.after_validate
ProgressCallback.after_validate ()
关闭验证数据加载器上的进度条
ProgressCallback.after_fit
ProgressCallback.after_fit ()
关闭主进度条
ShowGraphCallback
ShowGraphCallback (after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None)
更新训练损失和验证损失图表
= synth_learner(cbs=ShowGraphCallback())
learn 5) learn.fit(
周期 | 训练损失 | 验证损失 | 时间 |
---|---|---|---|
0 | 17.683565 | 10.431150 | 00:00 |
1 | 15.232769 | 7.056944 | 00:00 |
2 | 12.470916 | 4.382421 | 00:00 |
3 | 10.000675 | 2.574951 | 00:00 |
4 | 7.943449 | 1.464153 | 00:00 |
0.1]])) learn.predict(torch.tensor([[
(tensor([1.8955]), tensor([1.8955]), tensor([1.8955]))
CSVLogger
CSVLogger (fname='history.csv', append=False)
将结果记录到 learn.path/fname 中显示
如果设置 append,则结果追加到现有文件;否则覆盖现有文件。
= synth_learner(cbs=CSVLogger())
learn 5) learn.fit(
周期 | 训练损失 | 验证损失 | 时间 |
---|---|---|---|
0 | 15.606769 | 14.485189 | 00:00 |
1 | 13.840394 | 10.834929 | 00:00 |
2 | 11.842106 | 7.582738 | 00:00 |
3 | 9.937692 | 5.158300 | 00:00 |
4 | 8.244681 | 3.432087 | 00:00 |
CSVLogger.read_log
CSVLogger.read_log ()
快速访问日志的便捷方法。
= learn.csv_logger.read_log()
df
test_eq(df.columns.values, learn.recorder.metric_names)for i,v in enumerate(learn.recorder.values):
3], [i] + v)
test_close(df.iloc[i][:/learn.csv_logger.fname) os.remove(learn.path
CSVLogger.before_fit
CSVLogger.before_fit ()
准备包含指标名称的文件。
CSVLogger.after_fit
CSVLogger.after_fit ()
关闭文件并清理。