The main class for getting your data ready for model training is TabularDataLoaders and its factory methods. See the tabular tutorial for usage examples.
TabularDataLoaders
TabularDataLoaders (*loaders, path:str|pathlib.Path='.', device=None)
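Basic wrapper around several DataLoaders with factory methods for tabular data

|  | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| loaders | VAR_POSITIONAL |  | DataLoader objects to wrap |
| path | str \| pathlib.Path | . | Path to store export objects |
| device | NoneType | None | Device to put DataLoaders |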
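This class should not be used directly; use one of the factory methods instead. All of these factory methods accept the following arguments:

- cat_names: the names of the categorical variables
- cont_names: the names of the continuous variables
- y_names: the names of the dependent variables
- y_block: the TransformBlock to use for the target
- valid_idx: the indices to use for the validation set (defaults to a random split otherwise)
- bs: the batch size
- val_bs: the batch size for the validation DataLoader (defaults to bs)
- shuffle_train: whether to shuffle the training DataLoader
- n: overrides the number of elements in the dataset
- device: the PyTorch device to use (defaults to default_device())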
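To make the valid_idx behaviour concrete, here is a minimal stdlib sketch of the two splitting modes: an explicit list of validation indices versus a random split. The helper `split_indices` and its 20% default fraction are assumptions for illustration, not fastai's splitter code.

```python
import random


def split_indices(n_rows, valid_idx=None, valid_pct=0.2, seed=None):
    """Return (train_idx, valid_idx) for a table with n_rows rows.

    Hypothetical helper: with valid_idx the listed rows form the validation
    set; otherwise an assumed fraction valid_pct of rows is drawn at random.
    """
    if valid_idx is None:
        rng = random.Random(seed)
        shuffled = list(range(n_rows))
        rng.shuffle(shuffled)
        cut = int(valid_pct * n_rows)
        valid = sorted(shuffled[:cut])
    else:
        valid = sorted(set(valid_idx))
    valid_set = set(valid)
    # Every row not in the validation set goes to training
    train = [i for i in range(n_rows) if i not in valid_set]
    return train, valid


# Fixed split, like the valid_idx=list(range(800, 1000)) used on this page
train, valid = split_indices(1200, valid_idx=list(range(800, 1000)))
print(len(train), len(valid))  # 1000 200

# Random split with a fixed seed for reproducibility
train_r, valid_r = split_indices(100, seed=0)
print(len(train_r), len(valid_r))  # 80 20
```

A fixed valid_idx keeps the validation rows identical across runs, which makes metrics comparable between experiments; the random mode is only reproducible when a seed is fixed.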
TabularDataLoaders.from_df
TabularDataLoaders.from_df (df:pd.DataFrame, path:str|Path='.',
procs:list=None, cat_names:list=None,
cont_names:list=None, y_names:list=None,
y_block:TransformBlock=None,
valid_idx:list=None, bs:int=64,
shuffle_train:bool=None, shuffle:bool=True,
val_shuffle:bool=False, n:int=None,
device:torch.device=None,
drop_last:bool=None, val_bs:int=None)
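Create TabularDataLoaders from df in path using procs

|  | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| df | pd.DataFrame |  |  |
| path | str \| Path | . | Location of df, defaults to current working directory |
| procs | list | None | List of TabularProc |
| cat_names | list | None | Column names pertaining to categorical variables |
| cont_names | list | None | Column names pertaining to continuous variables |
| y_names | list | None | Names of the dependent variables |
| y_block | TransformBlock | None | TransformBlock to use for the target(s) |
| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |
| bs | int | 64 | Batch size |
| shuffle_train | bool | None | (Deprecated, use shuffle) Shuffle training DataLoader |
| shuffle | bool | True | Shuffle training DataLoader |
| val_shuffle | bool | False | Shuffle validation DataLoader |
| n | int | None | Size of Datasets used to create DataLoader |
| device | device | None | Device to put DataLoaders |
| drop_last | bool | None | Drop last incomplete batch, defaults to shuffle |
| val_bs | int | None | Validation batch size, defaults to bs |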
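The fallbacks in this table (val_bs defaults to bs, drop_last defaults to shuffle) change how many batches each loader yields. The hypothetical helper `batch_counts` below is a plain-Python sketch of that arithmetic; it assumes the validation loader always keeps its final partial batch, which is an assumption of this sketch rather than something stated on this page.

```python
import math


def batch_counts(n_train, n_valid, bs=64, val_bs=None, shuffle=True, drop_last=None):
    """Return (train_batches, valid_batches) under the documented fallbacks.

    Hypothetical helper; assumes the validation loader never drops its last batch.
    """
    if val_bs is None:
        val_bs = bs  # val_bs falls back to bs
    if drop_last is None:
        drop_last = shuffle  # drop_last falls back to shuffle
    n_train_batches = n_train // bs if drop_last else math.ceil(n_train / bs)
    n_valid_batches = math.ceil(n_valid / val_bs)
    return n_train_batches, n_valid_batches


print(batch_counts(1000, 200))                 # (15, 4)
print(batch_counts(1000, 200, shuffle=False))  # (16, 4)
print(batch_counts(1000, 200, val_bs=128))     # (15, 2)
```

Dropping the incomplete last batch when shuffling avoids a tiny, noisy final training batch; since the validation set is only evaluated, every row is kept there.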
Let's have a look at an example with the adult dataset:
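```python
from fastai.tabular.all import *

path = untar_data(URLs.ADULT_SAMPLE)
# skipinitialspace=True strips the space that follows each comma in adult.csv
df = pd.read_csv(path/'adult.csv', skipinitialspace=True)
df.head()
```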
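|  | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | salary |
| -- | --- | --------- | ------ | --------- | ------------- | -------------- | ---------- | ------------ | ---- | --- | ------------ | ------------ | -------------- | -------------- | ------ |
| 0 | 49 | Private | 101320 | Assoc-acdm | 12.0 | Married-civ-spouse | NaN | Wife | White | Female | 0 | 1902 | 40 | United-States | >=50k |
| 1 | 44 | Private | 236746 | Masters | 14.0 | Divorced | Exec-managerial | Not-in-family | White | Male | 10520 | 0 | 45 | United-States | >=50k |
| 2 | 38 | Private | 96185 | HS-grad | NaN | Divorced | NaN | Unmarried | Black | Female | 0 | 0 | 32 | United-States | <50k |
| 3 | 38 | Self-emp-inc | 112847 | Prof-school | 15.0 | Married-civ-spouse | Prof-specialty | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | United-States | >=50k |
| 4 | 42 | Self-emp-not-inc | 82297 | 7th-8th | NaN | Married-civ-spouse | Other-service | Wife | Black | Female | 0 | 0 | 50 | United-States | <50k |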
```python
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
dls.show_batch()
```
|  | workclass | education | marital-status | occupation | relationship | race | education-num_na | age | fnlwgt | education-num | salary |
| -- | --------- | --------- | -------------- | ---------- | ------------ | ---- | ---------------- | --- | ------ | ------------- | ------ |
| 0 | Private | HS-grad | Married-civ-spouse | Adm-clerical | Husband | White | False | 24.0 | 121312.998272 | 9.0 | <50k |
| 1 | Private | HS-grad | Never-married | Other-service | Not-in-family | White | False | 19.0 | 198320.000325 | 9.0 | <50k |
| 2 | Private | Bachelors | Married-civ-spouse | Sales | Husband | White | False | 66.0 | 169803.999308 | 13.0 | >=50k |
| 3 | Private | HS-grad | Divorced | Adm-clerical | Unmarried | White | False | 40.0 | 799280.980929 | 9.0 | <50k |
| 4 | Local-gov | 10th | Never-married | Other-service | Own-child | White | False | 18.0 | 55658.003629 | 6.0 | <50k |
| 5 | Private | HS-grad | Never-married | Handlers-cleaners | Other-relative | White | False | 30.0 | 375827.003847 | 9.0 | <50k |
| 6 | Private | Some-college | Never-married | Handlers-cleaners | Own-child | White | False | 20.0 | 173723.999335 | 10.0 | <50k |
| 7 | ? | Some-college | Never-married | ? | Own-child | White | False | 21.0 | 107800.997986 | 10.0 | <50k |
| 8 | Private | HS-grad | Never-married | Handlers-cleaners | Own-child | White | False | 19.0 | 263338.000072 | 9.0 | <50k |
| 9 | Private | Some-college | Married-civ-spouse | Tech-support | Husband | White | False | 35.0 | 194590.999986 | 10.0 | <50k |
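The batch above reflects the three procs: categorical columns are stored as integer codes, FillMissing has added an education-num_na flag, and continuous columns were filled and normalized before show_batch decoded them back to readable values (the small residue in fnlwgt, e.g. 121312.998272, is consistent with a normalize/denormalize round trip). Below is a simplified, stdlib-only sketch of those steps, fitted on training rows only. `fit_procs`, `apply_procs` and the row-as-dict layout are hypothetical; fastai's actual Categorify, FillMissing and Normalize operate on DataFrames and differ in detail.

```python
import statistics


def fit_procs(rows, cat_names, cont_names):
    """Learn Categorify/FillMissing/Normalize-like state from training rows.

    rows is a list of dicts; None marks a missing value (hypothetical layout).
    """
    state = {"vocab": {}, "median": {}, "mean": {}, "std": {}, "na_cols": set()}
    for c in cat_names:
        # Code 0 is reserved for missing or unseen values, like '#na#'
        levels = sorted({r[c] for r in rows if r[c] is not None})
        state["vocab"][c] = ["#na#"] + levels
    for c in cont_names:
        present = [r[c] for r in rows if r[c] is not None]
        if len(present) < len(rows):
            state["na_cols"].add(c)  # only columns with gaps get a _na flag
        median = statistics.median(present)
        filled = [median if r[c] is None else r[c] for r in rows]
        state["median"][c] = median
        state["mean"][c] = statistics.mean(filled)
        state["std"][c] = statistics.pstdev(filled) or 1.0  # avoid dividing by 0
    return state


def apply_procs(row, state, cat_names, cont_names):
    """Transform one row with the state learned by fit_procs (no refitting)."""
    out = {}
    for c in cat_names:
        vocab = state["vocab"][c]
        out[c] = vocab.index(row[c]) if row[c] in vocab[1:] else 0
    for c in cont_names:
        missing = row[c] is None
        if c in state["na_cols"]:
            out[c + "_na"] = missing
        value = state["median"][c] if missing else row[c]
        out[c] = (value - state["mean"][c]) / state["std"][c]
    return out


train = [
    {"workclass": "Private", "education-num": 12.0},
    {"workclass": "Self-emp-inc", "education-num": None},
    {"workclass": "Private", "education-num": 14.0},
]
cats, conts = ["workclass"], ["education-num"]
state = fit_procs(train, cats, conts)
print(apply_procs({"workclass": "Never-seen", "education-num": None}, state, cats, conts))
# {'workclass': 0, 'education-num_na': True, 'education-num': 0.0}
```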
TabularDataLoaders.from_csv
TabularDataLoaders.from_csv (csv:str|Path|io.BufferedReader,
skipinitialspace:bool=True,
path:str|Path='.', procs:list=None,
cat_names:list=None, cont_names:list=None,
y_names:list=None,
y_block:TransformBlock=None,
valid_idx:list=None, bs:int=64,
shuffle_train:bool=None, shuffle:bool=True,
val_shuffle:bool=False, n:int=None,
device:torch.device=None,
drop_last:bool=None, val_bs:int=None)
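Create TabularDataLoaders from csv file in path using procs

|  | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| csv | str \| Path \| io.BufferedReader |  | A csv of training data |
| skipinitialspace | bool | True | Skip spaces after delimiter |
| path | str \| Path | . | Location of df, defaults to current working directory |
| procs | list | None | List of TabularProc |
| cat_names | list | None | Column names pertaining to categorical variables |
| cont_names | list | None | Column names pertaining to continuous variables |
| y_names | list | None | Names of the dependent variables |
| y_block | TransformBlock | None | TransformBlock to use for the target(s) |
| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |
| bs | int | 64 | Batch size |
| shuffle_train | bool | None | (Deprecated, use shuffle) Shuffle training DataLoader |
| shuffle | bool | True | Shuffle training DataLoader |
| val_shuffle | bool | False | Shuffle validation DataLoader |
| n | int | None | Size of Datasets used to create DataLoader |
| device | device | None | Device to put DataLoaders |
| drop_last | bool | None | Drop last incomplete batch, defaults to shuffle |
| val_bs | int | None | Validation batch size, defaults to bs |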
```python
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                  y_names="salary", valid_idx=list(range(800,1000)), bs=64)
```
TabularDataLoaders.test_dl
TabularDataLoaders.test_dl (test_items, rm_type_tfms=None,
process:bool=True, inplace:bool=False, bs=16,
shuffle=False, after_batch=None,
num_workers=0, verbose:bool=False,
do_setup:bool=True, pin_memory=False,
timeout=0, batch_size=None, drop_last=False,
indexed=None, n=None, device=None,
persistent_workers=False,
pin_memory_device='', wif=None,
before_iter=None, after_item=None,
before_batch=None, after_iter=None,
create_batches=None, create_item=None,
create_batch=None, retain=None,
get_idxs=None, sample=None, shuffle_fn=None,
do_batch=None)
Create test TabDataLoader from test_items using validation procs

|  | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| test_items |  |  | Items to create new test TabDataLoader formatted the same as the training data |
| rm_type_tfms | NoneType | None | Number of Transforms to be removed from procs |
| process | bool | True | Apply validation TabularProcs to test_items immediately |
| inplace | bool | False | Keep separate copy of original test_items in memory if False |
| bs | int | 64 | Size of batch |
| shuffle | bool | False | Whether to shuffle data |
| after_batch | NoneType | None |  |
| num_workers | int | None | Number of CPU cores to use in parallel (default: all available up to 16) |
| verbose | bool | False | Whether to print verbose logs |
| do_setup | bool | True | Whether to run setup() for batch transform(s) |
| pin_memory | bool | False |  |
| timeout | int | 0 |  |
| batch_size | NoneType | None |  |
| drop_last | bool | False |  |
| indexed | NoneType | None |  |
| n | NoneType | None |  |
| device | NoneType | None |  |
| persistent_workers | bool | False |  |
| pin_memory_device | str | '' |  |
| wif | NoneType | None |  |
| before_iter | NoneType | None |  |
| after_item | NoneType | None |  |
| before_batch | NoneType | None |  |
| after_iter | NoneType | None |  |
| create_batches | NoneType | None |  |
| create_item | NoneType | None |  |
| create_batch | NoneType | None |  |
| retain | NoneType | None |  |
| get_idxs | NoneType | None |  |
| sample | NoneType | None |  |
| shuffle_fn | NoneType | None |  |
| do_batch | NoneType | None |  |
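test_dl applies the validation procs, i.e. state learned on the training set, to new items instead of fitting them again. The toy `NormalizeSketch` class below (hypothetical, not part of fastai) illustrates why that matters: normalizing a single test row with its own statistics would erase all of its information.

```python
import statistics


class NormalizeSketch:
    """Toy stand-in for a normalizing proc; setup() sees training data only."""

    def setup(self, train_values):
        self.mean = statistics.mean(train_values)
        self.std = statistics.pstdev(train_values) or 1.0  # guard against std == 0
        return self

    def __call__(self, values):
        return [(v - self.mean) / self.std for v in values]


train_ages = [20.0, 30.0, 40.0, 50.0]
proc = NormalizeSketch().setup(train_ages)

# test_dl-style processing: reuse the fitted proc, never call setup() again
test_ages = [49.0]
print(proc(test_ages))  # positive, since 49 is above the training mean of 35

# Refitting on the test row alone collapses it to the mean
print(NormalizeSketch().setup(test_ages)(test_ages))  # [0.0]
```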
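External structured data files can contain unexpected spaces, e.g. after a comma. We can see this in the first row of adult.csv: "49, Private,101320, ...". Trimming is often needed. Pandas has a convenient parameter, skipinitialspace, which is exposed by TabularDataLoaders.from_csv(). Otherwise, category labels used later for inference, such as workclass: Private, will be wrongly categorized as 0 or "#na#" if the training label was read as " Private". Let's test this feature: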
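```python
from fastcore.test import test_ne

# One row of inference data whose labels have no leading spaces
test_data = {
    'age': [49],
    'workclass': ['Private'],
    'fnlwgt': [101320],
    'education': ['Assoc-acdm'],
    'education-num': [12.0],
    'marital-status': ['Married-civ-spouse'],
    'occupation': [''],
    'relationship': ['Wife'],
    'race': ['White'],
}
test_df = pd.DataFrame(test_data)
tdl = dls.test_dl(test_df)
# After Categorify, 'Private' should map to a real category code, not 0 ('#na#')
test_ne(0, tdl.dataset.iloc[0]['workclass'])
```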
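The same pitfall can be reproduced with Python's built-in csv module, whose reader also accepts a skipinitialspace option. The sample text and the `read_column` helper are illustrative assumptions; the vocabulary with code 0 reserved for "#na#" only mimics the behaviour described above.

```python
import csv
import io

# Hypothetical excerpt in the style of adult.csv, with a space after some commas
raw = "age,workclass,fnlwgt\n49, Private,101320\n44, Self-emp-inc,236746\n"


def read_column(text, name, skipinitialspace):
    """Read one column from CSV text with or without trimming leading spaces."""
    reader = csv.DictReader(io.StringIO(text), skipinitialspace=skipinitialspace)
    return [row[name] for row in reader]


untrimmed = read_column(raw, "workclass", skipinitialspace=False)
trimmed = read_column(raw, "workclass", skipinitialspace=True)
print(untrimmed)  # [' Private', ' Self-emp-inc']
print(trimmed)    # ['Private', 'Self-emp-inc']

# Vocabulary built from untrimmed training labels, code 0 reserved for '#na#'
vocab = ["#na#"] + sorted(set(untrimmed))
code = vocab.index("Private") if "Private" in vocab else 0
print(code)  # 0: the clean inference label is treated as unknown
```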