Captum

In this notebook we will use the following data:

from fastai.vision.all import *
from random import randint

# Download the Oxford-IIIT Pet dataset and collect the image files
path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)

# In this dataset, cat images have filenames starting with an uppercase letter
def is_cat(x): return x[0].isupper()

dls = ImageDataLoaders.from_name_func(
    path, fnames, valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(128))

# Fine-tune a pretrained resnet34 for one epoch
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)

Captum Interpretation

This Distill article (here) provides a good overview of how to choose a baseline image. We can try each of the options in turn.
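As a rough illustration of the idea (this is only a sketch, not fastai's internal construction), the baseline choices discussed in the article amount to reference tensors of the same shape as the input image, e.g. for the 128x128 images used above:

import torch

img_shape = (3, 128, 128)
constant_baseline = torch.zeros(img_shape)   # a constant (all-black) image
uniform_baseline  = torch.rand(img_shape)    # uniform random noise
gauss_baseline    = torch.randn(img_shape)   # Gaussian random noise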


source

CaptumInterpretation

 CaptumInterpretation (learn, cmap_name='custom blue', colors=None, N=256,
                       methods=('original_image', 'heat_map'),
                       signs=('all', 'positive'), outlier_perc=1)

Captum Interpretation for Resnet
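The visualization settings in the signature above can be overridden at construction time. The sketch below is only illustrative: it assumes the `methods` and `signs` tuples are forwarded to Captum's image-attribution plotting utilities, so any values those accept (e.g. 'blended_heat_map', 'absolute_value') should be valid; `captum_custom` is a hypothetical name.

# Illustrative only: a hypothetical instance with non-default visualization settings
captum_custom = CaptumInterpretation(learn,
                                     methods=('original_image', 'blended_heat_map'),
                                     signs=('all', 'absolute_value'),
                                     outlier_perc=2)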

Interpretation

captum=CaptumInterpretation(learn)
idx=randint(0,len(fnames)-1)  # pick a random image to explain
captum.visualize(fnames[idx])

captum.visualize(fnames[idx],baseline_type='uniform')  # uniform random-noise baseline

captum.visualize(fnames[idx],baseline_type='gauss')  # Gaussian-noise baseline

captum.visualize(fnames[idx],metric='NT',baseline_type='uniform')  # NoiseTunnel attribution

captum.visualize(fnames[idx],metric='Occl',baseline_type='gauss')  # Occlusion attribution

Captum Insights Callback

from captum.insights import Batch  # the batch container Captum Insights expects

@patch
def _formatted_data_iter(x: CaptumInterpretation, dl, normalize_func):
    # Decode the normalized batches back to images and yield them as Captum Insights `Batch`es
    dl_iter = iter(dl)
    while True:
        images, labels = next(dl_iter)
        images = normalize_func.decode(images).to(dl.device)
        yield Batch(inputs=images, labels=labels)
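For context, this iterator is the piece that feeds batches into Captum Insights. Below is a rough sketch of how such an iterator plugs into Captum's `AttributionVisualizer`; it is not fastai's exact implementation, and `dl`, `normalize_func` and `baseline_func` are placeholders for the dataloader and transforms that `insights` prepares internally.

import torch
from captum.insights import AttributionVisualizer
from captum.insights.attr_vis.features import ImageFeature

# Rough sketch only: `dl`, `normalize_func` and `baseline_func` are placeholders
visualizer = AttributionVisualizer(
    models=[learn.model],
    score_func=lambda out: torch.nn.functional.softmax(out, dim=1),
    classes=[str(c) for c in dls.vocab],
    features=[ImageFeature("Image",
                           baseline_transforms=[baseline_func],
                           input_transforms=[normalize_func])],
    dataset=captum._formatted_data_iter(dl, normalize_func))
visualizer.render()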

source

CaptumInterpretation.insights

 CaptumInterpretation.insights (x:__main__.CaptumInterpretation, inp_data,
                                debug=True)
captum=CaptumInterpretation(learn)
captum.insights(fnames)