Predictions callbacks

Various callbacks to customize the behavior of get_preds

MCDropoutCallback

Turns on dropout during inference, so that you can call Learner.get_preds multiple times and use Monte Carlo Dropout to approximate the uncertainty of your model's predictions.
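To illustrate what "turning on dropout during inference" means, here is a minimal sketch in plain PyTorch. It is not necessarily how MCDropoutCallback is implemented; the helper name `enable_dropout` and the toy model are assumptions made up for this example. The idea is to keep the model in eval mode overall while switching only the dropout layers back to training mode, so that each forward pass samples a different dropout mask.

```python
import torch
from torch import nn

def enable_dropout(model: nn.Module) -> None:
    """Hypothetical helper: put only the dropout layers in training mode."""
    for module in model.modules():
        if "dropout" in type(module).__name__.lower():
            module.train()

# Toy model, assumed for illustration only
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 1))
model.eval()            # ordinary inference: dropout is disabled
enable_dropout(model)   # dropout is active again; other layers stay in eval mode

x = torch.randn(32, 4)
with torch.no_grad():
    # Ten stochastic forward passes over the same batch
    samples = torch.stack([model(x) for _ in range(10)])
```

Each of the ten passes uses its own random dropout mask, so `samples` holds ten different predictions for the same batch of inputs.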



MCDropoutCallback

 MCDropoutCallback (after_create=None, before_fit=None, before_epoch=None,
                    before_train=None, before_batch=None, after_pred=None,
                    after_loss=None, before_backward=None,
                    after_cancel_backward=None, after_backward=None,
                    before_step=None, after_cancel_step=None,
                    after_step=None, after_cancel_batch=None,
                    after_batch=None, after_cancel_train=None,
                    after_train=None, before_validate=None,
                    after_cancel_validate=None, after_validate=None,
                    after_cancel_epoch=None, after_epoch=None,
                    after_cancel_fit=None, after_fit=None)

Basic class handling tweaks of the training loop by changing a Learner in various events

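import torch
from fastai.test_utils import synth_learner
from fastai.callback.preds import MCDropoutCallback

learn = synth_learner()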

# Call get_preds 10 times, then stack the predictions, yielding a tensor with shape [# of samples, batch_size, ...]
dist_preds = []
for i in range(10):
    preds, targs = learn.get_preds(cbs=[MCDropoutCallback()])
    dist_preds += [preds]

torch.stack(dist_preds).shape
torch.Size([10, 32, 1])
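Once the predictions are stacked, a common next step is to summarize them along the sample dimension. The sketch below is an assumption about typical usage, not part of the library: it uses random tensors as a stand-in for the `dist_preds` list collected above, then takes the mean as the point prediction and the standard deviation as a rough per-item uncertainty estimate.

```python
import torch

torch.manual_seed(0)

# Stand-in for the list built by the loop above: 10 samples of shape [32, 1]
dist_preds = [torch.randn(32, 1) for _ in range(10)]

stacked = torch.stack(dist_preds)   # shape [10, 32, 1]
mean_pred = stacked.mean(dim=0)     # averaged prediction per item, shape [32, 1]
std_pred = stacked.std(dim=0)       # spread across samples per item, shape [32, 1]
```

A larger value in `std_pred` means the stochastic passes disagree more for that item, which is usually read as higher model uncertainty.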