Wandb
与 Weights & Biases 集成
首先,您需要安装 wandb:
pip install wandb
创建一个免费账户,然后运行
wandb login
在终端中。按照链接获取您需要粘贴的 API token,然后您就可以开始了!
WandbCallback
WandbCallback (log:str=None, log_preds:bool=True, log_preds_every_epoch:bool=False, log_model:bool=False, model_name:str=None, log_dataset:bool=False, dataset_name:str=None, valid_dl:fastai.data.core.TfmdDL=None, n_preds:int=36, seed:int=12345, reorder=True)
保存模型拓扑结构、损失和指标
类型 | 默认值 | 详情 | |
---|---|---|---|
log | str | None | 要记录什么(可以是 gradients , parameters , all 或 None) |
log_preds | bool | True | 是否将模型预测记录到 wandb.Table 中 |
log_preds_every_epoch | bool | False | 是在每个 epoch 记录预测还是在训练结束时记录 |
log_model | bool | False | 是否将模型检查点保存到 wandb.Artifact 中 |
model_name | str | None | 要保存的 model_name 名称,会覆盖 SaveModelCallback |
log_dataset | bool | False | 是否将数据集记录到 wandb.Artifact 中 |
dataset_name | str | None | 用于记录数据集的名称 |
valid_dl | TfmdDL | None | 如果 log_preds=True ,样本将从 valid_dl 中抽取 |
n_preds | int | 36 | 记录多少个预测样本 |
seed | int | 12345 | 抽取样本的种子 |
reorder | bool | True |
可选地根据 log
参数(可以是 “gradients”, “parameters”, “all” 或 None)记录权重和/或梯度,如果 log_preds=True
则记录样本预测(这些预测来自 valid_dl
或验证集的随机样本,由 seed
确定)。在这种情况下,会记录 n_preds
个样本。
如果与 SaveModelCallback
结合使用,也会保存最佳模型(可以通过设置 log_model=False
来禁用)。
数据集也可以被跟踪
- 如果
log_dataset
为True
,则跟踪的文件夹从learn.dls.path
中获取 log_dataset
可以明确设置为要跟踪的文件夹- 数据集的名称可以通过
dataset_name
明确指定,否则默认为文件夹名称 - 注意:子文件夹“models”总是会被忽略
对于自定义场景,您还可以手动使用 log_dataset
和 log_model
函数分别记录您自己的数据集和模型。
Learner.gather_args
Learner.gather_args ()
收集学习器可访问的配置参数
Learner.gather_args
Learner.gather_args ()
收集学习器可访问的配置参数
log_dataset
log_dataset (path, name=None, metadata={}, description='raw dataset')
记录数据集文件夹
log_model
log_model (path, name=None, metadata={}, description='trained model')
记录模型文件
使用示例
定义好您的 Learner
后,在调用 fit
或 fit_one_cycle
之前,您需要初始化 wandb
import wandb
wandb.init()
如果您想在没有账户的情况下使用 Weights & Biases,可以调用 wandb.init(anonymous='allow')
。
然后将回调函数添加到您的 learner
或 fit
方法调用中,如果您想保存最佳模型,可能还需要结合使用 SaveModelCallback
from fastai.callback.wandb import *
# To log only during one training phase
learn.fit(..., cbs=WandbCallback())
# To log continuously for all training phases
learn = learner(..., cbs=WandbCallback())
数据集和模型可以通过回调函数进行跟踪,也可以直接通过 log_model
和 log_dataset
函数进行跟踪。
更多详情,请参考 W&B 文档。