Wandb

Weights & Biases 集成

首先,您需要安装 wandb:

pip install wandb

创建一个免费账户,然后运行

wandb login

在终端中。按照链接获取您需要粘贴的 API token,然后您就可以开始了!


源码

WandbCallback

 WandbCallback (log:str=None, log_preds:bool=True,
                log_preds_every_epoch:bool=False, log_model:bool=False,
                model_name:str=None, log_dataset:bool=False,
                dataset_name:str=None,
                valid_dl:fastai.data.core.TfmdDL=None, n_preds:int=36,
                seed:int=12345, reorder=True)

保存模型拓扑结构、损失和指标

类型 默认值 详情
log str None 要记录什么(可以是 gradients, parameters, all 或 None)
log_preds bool True 是否将模型预测记录到 wandb.Table
log_preds_every_epoch bool False 是在每个 epoch 记录预测还是在训练结束时记录
log_model bool False 是否将模型检查点保存到 wandb.Artifact
model_name str None 要保存的 model_name 名称,会覆盖 SaveModelCallback
log_dataset bool False 是否将数据集记录到 wandb.Artifact
dataset_name str None 用于记录数据集的名称
valid_dl TfmdDL None 如果 log_preds=True,样本将从 valid_dl 中抽取
n_preds int 36 记录多少个预测样本
seed int 12345 抽取样本的种子
reorder bool True

可选地根据 log 参数(可以是 “gradients”, “parameters”, “all” 或 None)记录权重和/或梯度,如果 log_preds=True 则记录样本预测(这些预测来自 valid_dl 或验证集的随机样本,由 seed 确定)。在这种情况下,会记录 n_preds 个样本。

如果与 SaveModelCallback 结合使用,也会保存最佳模型(可以通过设置 log_model=False 来禁用)。

数据集也可以被跟踪

  • 如果 log_datasetTrue,则跟踪的文件夹从 learn.dls.path 中获取
  • log_dataset 可以明确设置为要跟踪的文件夹
  • 数据集的名称可以通过 dataset_name 明确指定,否则默认为文件夹名称
  • 注意:子文件夹“models”总是会被忽略

对于自定义场景,您还可以手动使用 log_datasetlog_model 函数分别记录您自己的数据集和模型。


源码

Learner.gather_args

 Learner.gather_args ()

收集学习器可访问的配置参数


源码

Learner.gather_args

 Learner.gather_args ()

收集学习器可访问的配置参数


源码

log_dataset

 log_dataset (path, name=None, metadata={}, description='raw dataset')

记录数据集文件夹


源码

log_model

 log_model (path, name=None, metadata={}, description='trained model')

记录模型文件

使用示例

定义好您的 Learner 后,在调用 fitfit_one_cycle 之前,您需要初始化 wandb

import wandb
wandb.init()

如果您想在没有账户的情况下使用 Weights & Biases,可以调用 wandb.init(anonymous='allow')

然后将回调函数添加到您的 learnerfit 方法调用中,如果您想保存最佳模型,可能还需要结合使用 SaveModelCallback

from fastai.callback.wandb import *

# To log only during one training phase
learn.fit(..., cbs=WandbCallback())

# To log continuously for all training phases
learn = learner(..., cbs=WandbCallback())

数据集和模型可以通过回调函数进行跟踪,也可以直接通过 log_modellog_dataset 函数进行跟踪。

更多详情,请参考 W&B 文档