进度与日志记录

用于跟踪训练进度或记录结果的回调函数和辅助函数
from fastai.test_utils import *

源文件

ProgressCallback

 ProgressCallback (after_create=None, before_fit=None, before_epoch=None,
                   before_train=None, before_batch=None, after_pred=None,
                   after_loss=None, before_backward=None,
                   after_cancel_backward=None, after_backward=None,
                   before_step=None, after_cancel_step=None,
                   after_step=None, after_cancel_batch=None,
                   after_batch=None, after_cancel_train=None,
                   after_train=None, before_validate=None,
                   after_cancel_validate=None, after_validate=None,
                   after_cancel_epoch=None, after_epoch=None,
                   after_cancel_fit=None, after_fit=None)

一个 Callback 用于处理进度条显示

learn = synth_learner()
learn.fit(5)
周期 训练损失 验证损失 时间
0 14.523648 10.988108 00:00
1 12.395808 7.306935 00:00
2 10.121231 4.370981 00:00
3 8.065226 2.487984 00:00
4 6.374166 1.368232 00:00

no_bar

 no_bar ()

禁用进度条使用的上下文管理器

learn = synth_learner()
with learn.no_bar(): learn.fit(5)
[0, 15.748106002807617, 12.352150917053223, '00:00']
[1, 13.818815231323242, 8.879858016967773, '00:00']
[2, 11.650713920593262, 5.857329845428467, '00:00']
[3, 9.595088005065918, 3.7397098541259766, '00:00']
[4, 7.814438343048096, 2.327916145324707, '00:00']

源文件

ProgressCallback.before_fit

 ProgressCallback.before_fit ()

在训练周期上设置主进度条


源文件

ProgressCallback.before_epoch

 ProgressCallback.before_epoch ()

更新主进度条


源文件

ProgressCallback.before_train

 ProgressCallback.before_train ()

启动训练数据加载器上的进度条


源文件

ProgressCallback.before_validate

 ProgressCallback.before_validate ()

启动验证数据加载器上的进度条


源文件

ProgressCallback.after_batch

 ProgressCallback.after_batch ()

更新当前进度条


源文件

ProgressCallback.after_train

 ProgressCallback.after_train ()

关闭训练数据加载器上的进度条


源文件

ProgressCallback.after_validate

 ProgressCallback.after_validate ()

关闭验证数据加载器上的进度条


源文件

ProgressCallback.after_fit

 ProgressCallback.after_fit ()

关闭主进度条


源文件

ShowGraphCallback

 ShowGraphCallback (after_create=None, before_fit=None, before_epoch=None,
                    before_train=None, before_batch=None, after_pred=None,
                    after_loss=None, before_backward=None,
                    after_cancel_backward=None, after_backward=None,
                    before_step=None, after_cancel_step=None,
                    after_step=None, after_cancel_batch=None,
                    after_batch=None, after_cancel_train=None,
                    after_train=None, before_validate=None,
                    after_cancel_validate=None, after_validate=None,
                    after_cancel_epoch=None, after_epoch=None,
                    after_cancel_fit=None, after_fit=None)

更新训练损失和验证损失图表

learn = synth_learner(cbs=ShowGraphCallback())
learn.fit(5)
周期 训练损失 验证损失 时间
0 17.683565 10.431150 00:00
1 15.232769 7.056944 00:00
2 12.470916 4.382421 00:00
3 10.000675 2.574951 00:00
4 7.943449 1.464153 00:00

learn.predict(torch.tensor([[0.1]]))
(tensor([1.8955]), tensor([1.8955]), tensor([1.8955]))

源文件

CSVLogger

 CSVLogger (fname='history.csv', append=False)

将结果记录到 learn.path/fname 中显示

如果设置 append,则结果追加到现有文件;否则覆盖现有文件。

learn = synth_learner(cbs=CSVLogger())
learn.fit(5)
周期 训练损失 验证损失 时间
0 15.606769 14.485189 00:00
1 13.840394 10.834929 00:00
2 11.842106 7.582738 00:00
3 9.937692 5.158300 00:00
4 8.244681 3.432087 00:00

源文件

CSVLogger.read_log

 CSVLogger.read_log ()

快速访问日志的便捷方法。

df = learn.csv_logger.read_log()
test_eq(df.columns.values, learn.recorder.metric_names)
for i,v in enumerate(learn.recorder.values):
    test_close(df.iloc[i][:3], [i] + v)
os.remove(learn.path/learn.csv_logger.fname)

源文件

CSVLogger.before_fit

 CSVLogger.before_fit ()

准备包含指标名称的文件。


源文件

CSVLogger.after_fit

 CSVLogger.after_fit ()

关闭文件并清理。