指标

可用于模型训练的指标定义

核心指标

本节定义了将 scikit-learn 指标转换为 fastai 指标的函数。除非您想了解 fastai 的所有内部细节,否则可以跳过本节。


源代码

AccumMetric

 AccumMetric (func, dim_argmax=None, activation='no', thresh=None,
              to_np=False, invert_arg=False, flatten=True, name=None,
              **kwargs)

在 CPU 上累积存储预测和目标,以便使用 func 执行最终计算。

仅在请求 value 属性时(例如在验证/训练阶段结束时,与 Learner 及其 Recorder 配合使用)才将 func 应用于累积的预测/目标。func 的签名应为 inp,targ(其中 inp 是模型的预测,targ 是相应的标签)。

对于单标签分类问题,预测需要先经过 softmax 然后 argmax 转换,再与目标进行比较。由于 softmax 不改变数字的顺序,我们可以只应用 argmax。传递 dim_argmax 参数即可让 AccumMetric 执行此操作(通常 -1 效果很好)。如果您需要将概率而不是预测传递给指标,请使用 softmax=True

对于多标签分类问题,或者如果您的目标是 one-hot 编码,预测可能需要先经过 sigmoid(如果模型中未包含)然后与给定阈值进行比较(用于决定 0 和 1),如果传递 sigmoid=True 和/或 thresh 的值,AccumMetric 会完成此操作。

如果您想使用 scikit-learn.metrics 的指标函数,您需要使用 to_np=True 将预测和标签转换为 numpy 数组。此外,scikit-learn 指标采用 y_truey_preds 的约定,这与我们相反,因此您需要传递 invert_arg=True 来让 AccumMetric 为您进行反转。

#For testing: a fake learner and a metric that isn't an average
@delegates()
class TstLearner(Learner):
    def __init__(self,dls=None,model=None,**kwargs): self.pred,self.xb,self.yb = None,None,None
def _l2_mean(x,y): return torch.sqrt((x.float()-y.float()).pow(2).mean())

#Go through a fake cycle with various batch sizes and computes the value of met
def compute_val(met, x1, x2):
    met.reset()
    vals = [0,6,15,20]
    learn = TstLearner()
    for i in range(3):
        learn.pred,learn.yb = x1[vals[i]:vals[i+1]],(x2[vals[i]:vals[i+1]],)
        met.accumulate(learn)
    return met.value
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccumMetric(_l2_mean)
test_close(compute_val(tst, x1, x2), _l2_mean(x1, x2))
test_eq(torch.cat(tst.preds), x1.view(-1))
test_eq(torch.cat(tst.targs), x2.view(-1))

#test argmax
x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))
tst = AccumMetric(_l2_mean, dim_argmax=-1)
test_close(compute_val(tst, x1, x2), _l2_mean(x1.argmax(dim=-1), x2))

#test thresh
x1,x2 = torch.randn(20,5),torch.randint(0, 2, (20,5)).bool()
tst = AccumMetric(_l2_mean, thresh=0.5)
test_close(compute_val(tst, x1, x2), _l2_mean((x1 >= 0.5), x2))

#test sigmoid
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccumMetric(_l2_mean, activation=ActivationType.Sigmoid)
test_close(compute_val(tst, x1, x2), _l2_mean(torch.sigmoid(x1), x2))

#test to_np
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccumMetric(lambda x,y: isinstance(x, np.ndarray) and isinstance(y, np.ndarray), to_np=True)
assert compute_val(tst, x1, x2)

#test invert_arg
x1,x2 = torch.randn(20,5),torch.randn(20,5)
tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()))
test_close(compute_val(tst, x1, x2), torch.sqrt(x1.pow(2).mean()))
tst = AccumMetric(lambda x,y: torch.sqrt(x.pow(2).mean()), invert_arg=True)
test_close(compute_val(tst, x1, x2), torch.sqrt(x2.pow(2).mean()))

源代码

skm_to_fastai

 skm_to_fastai (func, is_class=True, thresh=None, axis=-1,
                activation=None, **kwargs)

将 scikit-learn.metrics 的 func 转换为 fastai 指标

这是在 fastai 训练循环中使用 scikit-learn 指标最快捷的方法。is_class 指示您是否在处理分类问题。在这种情况下

  • thresh 设置为 None 表示这是单标签分类问题,预测将先经过 axis 上的 argmax,再与目标进行比较
  • thresh 设置一个值表示这是多标签分类问题,预测将先经过 sigmoid(可以通过 sigmoid=False 关闭)并与 thresh 比较,再与目标进行比较

如果 is_class=False,则表示您在处理回归问题,预测直接与目标进行比较,不进行修改。在所有情况下,kwargs 是传递给 func 的额外关键字参数。

tst_single = skm_to_fastai(skm.precision_score)
x1,x2 = torch.randn(20,2),torch.randint(0, 2, (20,))
test_close(compute_val(tst_single, x1, x2), skm.precision_score(x2, x1.argmax(dim=-1)))
tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2)
x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))
test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, torch.sigmoid(x1) >= 0.2))

tst_multi = skm_to_fastai(skm.precision_score, thresh=0.2, activation=ActivationType.No)
x1,x2 = torch.randn(20),torch.randint(0, 2, (20,))
test_close(compute_val(tst_multi, x1, x2), skm.precision_score(x2, x1 >= 0.2))
tst_reg = skm_to_fastai(skm.r2_score, is_class=False)
x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_close(compute_val(tst_reg, x1, x2), skm.r2_score(x2.view(-1).numpy(), x1.view(-1).numpy()))
test_close(tst_reg(x1, x2), skm.r2_score(x2.view(-1).numpy(), x1.view(-1).numpy()))

源代码

optim_metric

 optim_metric (f, argname, bounds, tol=0.01, do_neg=True, get_x=False)

将指标 f 替换为优化参数 argname 的版本

单标签分类

警告

本节中定义的所有函数都适用于单标签分类和非 one-hot 编码的目标。对于多标签问题或 one-hot 编码的目标,请使用带有 multi 后缀的版本。

警告

fastai 中的许多指标是对 sklearn 功能的轻量级封装。然而,sklearn 指标可以处理 Python 列表字符串等,而 fastai 指标与 PyTorch 配合使用,因此需要张量。传递给指标的参数是在所有转换(例如类别转换为索引)发生之后的值。这意味着,例如,当您传递指标的标签时,必须传递索引而不是字符串。这可以使用 vocab.map_obj 进行转换。


源代码

准确率

 accuracy (inp, targ, axis=-1)

pred 的形状为 bs * n_classes* 时计算与 targ 的准确率

#For testing
def change_targ(targ, n, c):
    idx = torch.randperm(len(targ))[:n]
    res = targ.clone()
    for i in idx: res[i] = (res[i]+random.randint(1,c-1))%c
    return res
x = torch.randn(4,5)
y = x.argmax(dim=1)
test_eq(accuracy(x,y), 1)
y1 = change_targ(y, 2, 5)
test_eq(accuracy(x,y1), 0.5)
test_eq(accuracy(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.75)

源代码

错误率

 error_rate (inp, targ, axis=-1)

1 - accuracy

x = torch.randn(4,5)
y = x.argmax(dim=1)
test_eq(error_rate(x,y), 0)
y1 = change_targ(y, 2, 5)
test_eq(error_rate(x,y1), 0.5)
test_eq(error_rate(x.unsqueeze(1).expand(4,2,5), torch.stack([y,y1], dim=1)), 0.25)

源代码

Top-k 准确率

 top_k_accuracy (inp, targ, k=5, axis=-1)

计算 Top-k 准确率(即 targinp 的前 k 个预测中)

x = torch.randn(6,5)
y = torch.arange(0,6)
test_eq(top_k_accuracy(x[:5],y[:5]), 1)
test_eq(top_k_accuracy(x, y), 5/6)

源代码

二分类平均精度得分

 APScoreBinary (axis=-1, average='macro', pos_label=1, sample_weight=None)

单标签二分类问题的平均精度

更多详细信息请参阅scikit-learn 文档


源代码

平衡准确率

 BalancedAccuracy (axis=-1, sample_weight=None, adjusted=False)

单标签二分类问题的平衡准确率

更多详细信息请参阅scikit-learn 文档


源代码

Brier 得分

 BrierScore (axis=-1, sample_weight=None, pos_label=None)

单标签分类问题的 Brier 得分

更多详细信息请参阅scikit-learn 文档


源代码

Cohen Kappa 系数

 CohenKappa (axis=-1, labels=None, weights=None, sample_weight=None)

单标签分类问题的 Cohen Kappa 系数

更多详细信息请参阅scikit-learn 文档


源代码

F1 得分

 F1Score (axis=-1, labels=None, pos_label=1, average='binary',
          sample_weight=None)

单标签分类问题的 F1 得分

更多详细信息请参阅scikit-learn 文档


源代码

FBeta 得分

 FBeta (beta, axis=-1, labels=None, pos_label=1, average='binary',
        sample_weight=None)

单标签分类问题中带有 beta 的 FBeta 得分

更多详细信息请参阅scikit-learn 文档


源代码

汉明损失

 HammingLoss (axis=-1, sample_weight=None)

单标签分类问题的汉明损失

更多详细信息请参阅scikit-learn 文档


源代码

Jaccard 得分

 Jaccard (axis=-1, labels=None, pos_label=1, average='binary',
          sample_weight=None)

单标签分类问题的 Jaccard 得分

更多详细信息请参阅scikit-learn 文档


源代码

精确率

 Precision (axis=-1, labels=None, pos_label=1, average='binary',
            sample_weight=None)

单标签分类问题的精确率

更多详细信息请参阅scikit-learn 文档


源代码

召回率

 Recall (axis=-1, labels=None, pos_label=1, average='binary',
         sample_weight=None)

单标签分类问题的召回率

更多详细信息请参阅scikit-learn 文档


源代码

ROC AUC

 RocAuc (axis=-1, average='macro', sample_weight=None, max_fpr=None,
         multi_class='ovr')

单标签多类别分类问题的接收者操作特征曲线下面积 (ROC AUC)

更多详细信息请参阅scikit-learn 文档


源代码

二分类 ROC AUC

 RocAucBinary (axis=-1, average='macro', sample_weight=None, max_fpr=None,
               multi_class='raise')

单标签二分类问题的接收者操作特征曲线下面积 (ROC AUC)

更多详细信息请参阅scikit-learn 文档


源代码

Matthews 相关系数

 MatthewsCorrCoef (sample_weight=None, **kwargs)

单标签分类问题的 Matthews 相关系数

更多详细信息请参阅scikit-learn 文档

多标签分类


源代码

多标签准确率

 accuracy_multi (inp, targ, thresh=0.5, sigmoid=True)

inptarg 大小相同时计算准确率。

#For testing
def change_1h_targ(targ, n):
    idx = torch.randperm(targ.numel())[:n]
    res = targ.clone().view(-1)
    for i in idx: res[i] = 1-res[i]
    return res.view(targ.shape)
x = torch.randn(4,5)
y = (torch.sigmoid(x) >= 0.5).byte()
test_eq(accuracy_multi(x,y), 1)
test_eq(accuracy_multi(x,1-y), 0)
y1 = change_1h_targ(y, 5)
test_eq(accuracy_multi(x,y1), 0.75)

#Different thresh
y = (torch.sigmoid(x) >= 0.2).byte()
test_eq(accuracy_multi(x,y, thresh=0.2), 1)
test_eq(accuracy_multi(x,1-y, thresh=0.2), 0)
y1 = change_1h_targ(y, 5)
test_eq(accuracy_multi(x,y1, thresh=0.2), 0.75)

#No sigmoid
y = (x >= 0.5).byte()
test_eq(accuracy_multi(x,y, sigmoid=False), 1)
test_eq(accuracy_multi(x,1-y, sigmoid=False), 0)
y1 = change_1h_targ(y, 5)
test_eq(accuracy_multi(x,y1, sigmoid=False), 0.75)

源代码

多标签平均精度得分

 APScoreMulti (sigmoid=True, average='macro', pos_label=1,
               sample_weight=None)

多标签分类问题的平均精度

更多详细信息请参阅scikit-learn 文档


源代码

多标签 Brier 得分

 BrierScoreMulti (thresh=0.5, sigmoid=True, sample_weight=None,
                  pos_label=None)

多标签分类问题的 Brier 得分

更多详细信息请参阅scikit-learn 文档


源代码

多标签 F1 得分

 F1ScoreMulti (thresh=0.5, sigmoid=True, labels=None, pos_label=1,
               average='macro', sample_weight=None)

多标签分类问题的 F1 得分

更多详细信息请参阅scikit-learn 文档


源代码

多标签 FBeta 得分

 FBetaMulti (beta, thresh=0.5, sigmoid=True, labels=None, pos_label=1,
             average='macro', sample_weight=None)

多标签分类问题中带有 beta 的 FBeta 得分

更多详细信息请参阅scikit-learn 文档


源代码

多标签汉明损失

 HammingLossMulti (thresh=0.5, sigmoid=True, labels=None,
                   sample_weight=None)

多标签分类问题的汉明损失

更多详细信息请参阅scikit-learn 文档


源代码

多标签 Jaccard 得分

 JaccardMulti (thresh=0.5, sigmoid=True, labels=None, pos_label=1,
               average='macro', sample_weight=None)

多标签分类问题的 Jaccard 得分

更多详细信息请参阅scikit-learn 文档


源代码

多标签 Matthews 相关系数

 MatthewsCorrCoefMulti (thresh=0.5, sigmoid=True, sample_weight=None)

多标签分类问题的 Matthews 相关系数

更多详细信息请参阅scikit-learn 文档


源代码

多标签精确率

 PrecisionMulti (thresh=0.5, sigmoid=True, labels=None, pos_label=1,
                 average='macro', sample_weight=None)

多标签分类问题的精确率

更多详细信息请参阅scikit-learn 文档


源代码

多标签召回率

 RecallMulti (thresh=0.5, sigmoid=True, labels=None, pos_label=1,
              average='macro', sample_weight=None)

多标签分类问题的召回率

更多详细信息请参阅scikit-learn 文档


源代码

多标签 ROC AUC

 RocAucMulti (sigmoid=True, average='macro', sample_weight=None,
              max_fpr=None)

多标签二分类问题的接收者操作特征曲线下面积 (ROC AUC)

roc_auc_metric = RocAucMulti(sigmoid=False)
x,y = torch.tensor([np.arange(start=0, stop=0.2, step=0.04)]*20), torch.tensor([0, 0, 1, 1]).repeat(5)
assert compute_val(roc_auc_metric, x, y) == 0.5
/var/folders/ss/34z569j921v58v8n1n_8z7h40000gn/T/ipykernel_38355/1899176771.py:2: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1712608632396/work/torch/csrc/utils/tensor_new.cpp:277.)
  x,y = torch.tensor([np.arange(start=0, stop=0.2, step=0.04)]*20), torch.tensor([0, 0, 1, 1]).repeat(5)

更多详细信息请参阅scikit-learn 文档

回归


源代码

均方误差

 mse (inp, targ)

计算 inptarg 之间的均方误差。

x1,x2 = torch.randn(4,5),torch.randn(4,5)
test_close(mse(x1,x2), (x1-x2).pow(2).mean())

源代码

均方根误差

 rmse (preds, targs)

均方根误差

x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_eq(compute_val(rmse, x1, x2), torch.sqrt(F.mse_loss(x1,x2)))

源代码

平均绝对误差

 mae (inp, targ)

计算 inptarg 之间的平均绝对误差。

x1,x2 = torch.randn(4,5),torch.randn(4,5)
test_eq(mae(x1,x2), torch.abs(x1-x2).mean())

源代码

均方对数误差

 msle (inp, targ)

计算 inptarg 之间的均方对数误差。

x1,x2 = torch.randn(4,5),torch.randn(4,5)
x1,x2 = torch.relu(x1),torch.relu(x2)
test_close(msle(x1,x2), (torch.log(x1+1)-torch.log(x2+1)).pow(2).mean())

源代码

指数均方根百分比误差

 exp_rmspe (preds, targs)

预测和目标的指数的均方根百分比误差

x1,x2 = torch.randn(20,5),torch.randn(20,5)
test_eq(compute_val(exp_rmspe, x1, x2), torch.sqrt((((torch.exp(x2) - torch.exp(x1))/torch.exp(x2))**2).mean()))

源代码

解释方差

 ExplainedVariance (sample_weight=None)

预测和目标之间的解释方差

更多详细信息请参阅scikit-learn 文档


源代码

R平方得分

 R2Score (sample_weight=None)

预测和目标之间的 R平方得分

更多详细信息请参阅scikit-learn 文档


源代码

Pearson 相关系数

 PearsonCorrCoef (dim_argmax=None, activation='no', thresh=None,
                  to_np=False, invert_arg=False, flatten=True, name=None)

回归问题的 Pearson 相关系数

更多详细信息请参阅scipy 文档

x = torch.randint(-999, 999,(20,))
y = torch.randint(-999, 999,(20,))
test_eq(compute_val(PearsonCorrCoef(), x, y), scs.pearsonr(x.view(-1), y.view(-1))[0])

源代码

Spearman 相关系数

 SpearmanCorrCoef (dim_argmax=None, axis=0, nan_policy='propagate',
                   activation='no', thresh=None, to_np=False,
                   invert_arg=False, flatten=True, name=None)

回归问题的 Spearman 相关系数

更多详细信息请参阅scipy 文档

x = torch.randint(-999, 999,(20,))
y = torch.randint(-999, 999,(20,))
test_eq(compute_val(SpearmanCorrCoef(), x, y), scs.spearmanr(x.view(-1), y.view(-1))[0])

分割

from fastai.vision.all import *
model = resnet34()
x = cast(torch.rand(1,3,128,128), TensorImage)
type(model(x))
fastai.torch_core.TensorImage

源代码

前景准确率

 foreground_acc (inp, targ, bkg_idx=0, axis=1)

计算多类别分割中的非背景准确率

x = cast(torch.randn(4,5,3,3), TensorImage)
y = cast(x, TensorMask).argmax(dim=1)[:,None]
test_eq(foreground_acc(x,y), 1)
y[0] = 0 #the 0s are ignored so we get the same value
test_eq(foreground_acc(x,y), 1)

源代码

Dice 系数

 Dice (axis=1)

分割任务中二分类目标的 Dice 系数指标

x1 = cast(torch.randn(20,2,3,3), TensorImage)
x2 = cast(torch.randint(0, 2, (20, 3, 3)), TensorMask)
pred = x1.argmax(1)
inter = (pred*x2).float().sum().item()
union = (pred+x2).float().sum().item()
test_eq(compute_val(Dice(), x1, x2), 2*inter/union)

源代码

多类别 Dice 系数

 DiceMulti (axis=1)

分割任务中多类别目标的平均 Dice 指标 (Macro F1)

DiceMulti 方法实现了此论文中描述的“平均 F1:调和平均的算术平均”: https://arxiv.org/pdf/1911.03347.pdf

x1a = torch.ones(20,1,1,1)
x1b = torch.clone(x1a)*0.5
x1c = torch.clone(x1a)*0.3
x1 = torch.cat((x1a,x1b,x1c),dim=1)   # Prediction: 20xClass0
x2 = torch.zeros(20,1,1)              # Target: 20xClass0
test_eq(compute_val(DiceMulti(), x1, x2), 1.)

x2 = torch.ones(20,1,1)               # Target: 20xClass1
test_eq(compute_val(DiceMulti(), x1, x2), 0.)

x2a = torch.zeros(10,1,1)
x2b = torch.ones(5,1,1)
x2c = torch.ones(5,1,1) * 2
x2 = torch.cat((x2a,x2b,x2c),dim=0)   # Target: 10xClass0, 5xClass1, 5xClass2
dice1 = (2*10)/(2*10+10)              # Dice: 2*TP/(2*TP+FP+FN)
dice2 = 0
dice3 = 0
test_eq(compute_val(DiceMulti(), x1, x2), (dice1+dice2+dice3)/3)

源代码

Jaccard 系数

 JaccardCoeff (axis=1)

更节省 RAM 的 Jaccard 系数实现

x1 = cast(torch.randn(20,2,3,3), TensorImage)
x2 = cast(torch.randint(0, 2, (20, 3, 3)), TensorMask)
pred = x1.argmax(1)
inter = (pred*x2).float().sum().item()
union = (pred+x2).float().sum().item()
test_eq(compute_val(JaccardCoeff(), x1, x2), inter/(union-inter))

源代码

多类别 Jaccard 系数

 JaccardCoeffMulti (axis=1)

分割任务中多类别目标的平均 Jaccard 系数指标 (mIoU)

x1a = torch.ones(20,1,1,1)
x1b = torch.clone(x1a)*0.5
x1c = torch.clone(x1a)*0.3
x1 = torch.cat((x1a,x1b,x1c), dim=1)   # Prediction: 20xClass0
x2 = torch.zeros(20,1,1)              # Target: 20xClass0
test_eq(compute_val(JaccardCoeffMulti(), x1, x2), 1.)

x2 = torch.ones(20,1,1)               # Target: 20xClass1
test_eq(compute_val(JaccardCoeffMulti(), x1, x2), 0.)

x2a = torch.zeros(10,1,1)
x2b = torch.ones(5,1,1)
x2c = torch.ones(5,1,1) * 2
x2 = torch.cat((x2a,x2b,x2c), dim=0)   # Target: 10xClass0, 5xClass1, 5xClass2
jcrd1 = 10/(10+10)              # Jaccard: TP/(TP+FP+FN)
jcrd2 = 0
jcrd3 = 0
test_eq(compute_val(JaccardCoeffMulti(), x1, x2), (jcrd1+jcrd2+jcrd3)/3)

NLP


源代码

语料库 BLEU 指标

 CorpusBLEUMetric (vocab_sz=5000, axis=-1)

定义指标的蓝图

def create_vcb_emb(pred, targ):
    # create vocab "embedding" for predictions
    vcb_sz = max(torch.unique(torch.cat([pred, targ])))+1
    pred_emb=torch.zeros(pred.size()[0], pred.size()[1] ,vcb_sz)
    for i,v in enumerate(pred):
        pred_emb[i].scatter_(1, v.view(len(v),1),1)
    return pred_emb

def compute_bleu_val(met, x1, x2):
    met.reset()
    learn = TstLearner()
    learn.training=False    
    for i in range(len(x1)): 
        learn.pred,learn.yb = x1, (x2,)
        met.accumulate(learn)
    return met.value

targ = torch.tensor([[1,2,3,4,5,6,1,7,8]]) 
pred = torch.tensor([[1,9,3,4,5,6,1,10,8]])
pred_emb = create_vcb_emb(pred, targ)
test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)

targ = torch.tensor([[1,2,3,4,5,6,1,7,8],[1,2,3,4,5,6,1,7,8]]) 
pred = torch.tensor([[1,9,3,4,5,6,1,10,8],[1,9,3,4,5,6,1,10,8]])
pred_emb = create_vcb_emb(pred, targ)
test_close(compute_bleu_val(CorpusBLEUMetric(), pred_emb, targ), 0.48549)

BLEU 指标在此文章中引入,用于评估翻译模型的性能。它基于预测中 n-grams 与目标的精确度。有关 BLEU 的更详细说明,请参阅fastai NLP 课程的 BLEU notebook

精确率计算中使用的平滑方法与SacreBLEU中的方法相同,该方法源自 Chen & Cherry 于 2014 年发表的论文中的“方法 3”。


源代码

困惑度

 Perplexity ()

语言模型的困惑度(交叉熵损失的指数)

x1,x2 = torch.randn(20,5),torch.randint(0, 5, (20,))
tst = perplexity
tst.reset()
vals = [0,6,15,20]
learn = TstLearner()
for i in range(3): 
    learn.yb = (x2[vals[i]:vals[i+1]],)
    learn.loss = F.cross_entropy(x1[vals[i]:vals[i+1]],x2[vals[i]:vals[i+1]])
    tst.accumulate(learn)
test_close(tst.value, torch.exp(F.cross_entropy(x1,x2)))

源代码

损失指标

 LossMetric (attr, nm=None)

loss_func.attr 创建名为 nm 的指标


源代码

损失指标集

 LossMetrics (attrs, nms=None)

针对 attrsnms 中每个元素的 LossMetric 列表

class CombineL1L2(Module):
    def forward(self, out, targ):
        self.l1 = F.l1_loss(out, targ)
        self.l2 = F.mse_loss(out, targ)
        return self.l1+self.l2
learn = synth_learner(metrics=LossMetrics('l1,l2'))
learn.loss_func = CombineL1L2()
learn.fit(2)
轮次 训练损失 验证损失 l1 l2 时间
0 15.296746 12.515826 3.019884 9.495943 00:00
1 13.290909 8.719325 2.454751 6.264574 00:00