预测结果解读

用于构建对象以更好地解释模型预测结果的类

源代码

解读

 Interpretation (learn:fastai.learner.Learner,
                 dl:fastai.data.load.DataLoader,
                 losses:fastai.torch_core.TensorBase, act=None)

Interpretation 基类,可以被继承用于特定任务的 Interpretation 类

类型 默认值 详情
learn Learner
dl DataLoader DataLoader 用于运行推理
losses TensorBase 根据 dl 计算的损失值
act NoneType None 用于预测的激活函数

Interpretation 是一个辅助基类,用于探索训练模型的预测结果。它可以被继承用于特定任务的解释类,例如 ClassificationInterpretationInterpretation 内存效率高,并且只要硬件能够训练相同的模型,它应该能够处理任何大小的数据集。

注意

由于在可能的情况下使用批量处理,并为每个项目动态生成输入、预测、目标、解码输出和损失,因此 Interpretation 具有内存效率。


源代码

Interpretation.from_learner

 Interpretation.from_learner (learn, ds_idx:int=1,
                              dl:fastai.data.load.DataLoader=None,
                              act=None)

从 Learner 构建 Interpretation 对象

类型 默认值 详情
learn 用于创建 Interpretation 的模型
ds_idx int 1 dl 为 None 时,learn.dls 的索引
dl DataLoader None 用于生成预测结果的 Dataloader
act NoneType None 覆盖默认的或设置预测激活函数

源代码

Interpretation.top_losses

 Interpretation.top_losses (k:int|None=None, largest:bool=True,
                            items:bool=False)

k 个最大(/最小)损失值及其索引,默认为所有损失值。

类型 默认值 详情
k int | None None 返回 k 个损失值,默认为所有
largest bool True 按最大或最小对损失值排序
items bool False 是否返回输入项

在默认设置 k=None 下,top_losses 将返回整个数据集的损失值。top_losses 可以选择包含每个损失值对应的输入项,通常是一个文件路径或 Pandas DataFrame


源代码

Interpretation.plot_top_losses

 Interpretation.plot_top_losses (k:int|collections.abc.MutableSequence,
                                 largest:bool=True, **kwargs)

显示 k 个最大(/最小)预测结果和损失值。实现基于类型分发

类型 默认值 详情
k int | collections.abc.MutableSequence 要绘制的损失值数量
largest bool True 按最大或最小对损失值排序
kwargs VAR_KEYWORD

绘制前 9 个最大损失值

interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9)

然后绘制第 7 到第 16 个最大损失值

interp.plot_top_losses(range(7,16))

源代码

Interpretation.show_results

 Interpretation.show_results (idxs:list, **kwargs)

显示 idxs 的预测结果和目标值

类型 详情
idxs list 预测结果和目标的索引
kwargs VAR_KEYWORD

Learner.show_results 类似,但可以传递期望的项目索引来显示结果。


源代码

ClassificationInterpretation

 ClassificationInterpretation (learn:fastai.learner.Learner,
                               dl:fastai.data.load.DataLoader,
                               losses:fastai.torch_core.TensorBase,
                               act=None)

用于分类模型的 Interpretation 方法。

类型 默认值 详情
learn Learner
dl DataLoader DataLoader 用于运行推理
losses TensorBase 根据 dl 计算的损失值
act NoneType None 用于预测的激活函数

源代码

ClassificationInterpretation.confusion_matrix

 ClassificationInterpretation.confusion_matrix ()

混淆矩阵,表示为 np.ndarray


源代码

ClassificationInterpretation.plot_confusion_matrix

 ClassificationInterpretation.plot_confusion_matrix (normalize:bool=False,
                                                     title:str='Confusion
                                                     matrix',
                                                     cmap:str='Blues',
                                                     norm_dec:int=2,
                                                     plot_txt:bool=True,
                                                     **kwargs)

绘制混淆矩阵,带有 title 并使用 cmap

类型 默认值 详情
normalize bool False 是否归一化出现次数
title str 混淆矩阵 图表标题
cmap str Blues matplotlib 的颜色映射
norm_dec int 2 归一化出现次数的小数位数
plot_txt bool True 在矩阵中显示出现次数
kwargs VAR_KEYWORD

源代码

ClassificationInterpretation.most_confused

 ClassificationInterpretation.most_confused (min_val=1)

按降序排列的混淆矩阵中最大的非对角线项(实际类别,预测类别,# 出现次数)


源代码

SegmentationInterpretation

 SegmentationInterpretation (learn:fastai.learner.Learner,
                             dl:fastai.data.load.DataLoader,
                             losses:fastai.torch_core.TensorBase,
                             act=None)

用于分割模型的 Interpretation 方法。

类型 默认值 详情
learn Learner
dl DataLoader DataLoader 用于运行推理
losses TensorBase 根据 dl 计算的损失值
act NoneType None 用于预测的激活函数