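The main class for getting your data ready for model training is TabularDataLoaders, together with its factory methods. See the tabular tutorial for usage examples.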
TabularDataLoaders
TabularDataLoaders (*loaders, path:str|pathlib.Path='.', device=None)
Basic wrapper around several DataLoader objects, with factory methods for tabular data
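| Parameter | Type | Default | Details |
|---|---|---|---|
| loaders | VAR_POSITIONAL |  | DataLoader objects to wrap |
| path | str \| pathlib.Path | . | Path to store export objects |
| device | NoneType | None | Device to put DataLoaders on |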
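This class should not be used directly; prefer one of the factory methods instead. All of these factory methods accept the following arguments:

- cat_names: the names of the categorical variables
- cont_names: the names of the continuous variables
- y_names: the names of the dependent variables
- y_block: the TransformBlock to use for the target
- valid_idx: the indices to use for the validation set (defaults to a random split otherwise)
- bs: the batch size
- val_bs: the batch size for the validation DataLoader (defaults to bs)
- shuffle_train: whether to shuffle the training DataLoader
- n: overrides the number of elements in the dataset
- device: the PyTorch device to use (defaults to default_device())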
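Since valid_idx is simply a list of row positions, it can be built by hand when a reproducible split is preferable to the default random one. The following is a minimal sketch using only the standard library; the helper name make_valid_idx, the 20% hold-out fraction, and the seed are illustrative assumptions and not part of fastai.

```python
import random

def make_valid_idx(n_rows, valid_pct=0.2, seed=42):
    """Hypothetical helper (not a fastai function): pick a reproducible
    random subset of row positions to use as the validation set."""
    rng = random.Random(seed)          # private RNG so the global state is untouched
    n_valid = int(n_rows * valid_pct)  # number of rows held out for validation
    return sorted(rng.sample(range(n_rows), n_valid))

# Assumed dataset size of 1000 rows, purely for illustration.
valid_idx = make_valid_idx(1000)
# Every row that is not in valid_idx would end up in the training set.
train_idx = sorted(set(range(1000)) - set(valid_idx))
print(len(valid_idx), len(train_idx))
```

The resulting list could then be passed as valid_idx=valid_idx to any of the factory methods.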
TabularDataLoaders.from_df
TabularDataLoaders.from_df (df:pd.DataFrame, path:str|Path='.',
procs:list=None, cat_names:list=None,
cont_names:list=None, y_names:list=None,
y_block:TransformBlock=None,
valid_idx:list=None, bs:int=64,
shuffle_train:bool=None, shuffle:bool=True,
val_shuffle:bool=False, n:int=None,
device:torch.device=None,
drop_last:bool=None, val_bs:int=None)
Create TabularDataLoaders from df in path using procs
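| Parameter | Type | Default | Details |
|---|---|---|---|
| df | pd.DataFrame |  |  |
| path | str \| Path | . | Location of df, defaults to the current working directory |
| procs | list | None | List of TabularProcs |
| cat_names | list | None | Column names pertaining to categorical variables |
| cont_names | list | None | Column names pertaining to continuous variables |
| y_names | list | None | Names of the dependent variables |
| y_block | TransformBlock | None | TransformBlock to use for the target(s) |
| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |
| bs | int | 64 | Batch size |
| shuffle_train | bool | None | (Deprecated, use shuffle) Shuffle the training DataLoader |
| shuffle | bool | True | Shuffle the training DataLoader |
| val_shuffle | bool | False | Shuffle the validation DataLoader |
| n | int | None | Size of the Datasets used to create the DataLoader |
| device | device | None | Device to put DataLoaders on |
| drop_last | bool | None | Drop the last incomplete batch, defaults to shuffle |
| val_bs | int | None | Validation batch size, defaults to bs |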
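The interaction between bs, val_bs, shuffle and drop_last determines how many batches each loader yields. The sketch below only mirrors the defaults documented in the table (val_bs falls back to bs, drop_last falls back to shuffle); both helper functions are hypothetical, the 800/200 row counts are assumed for illustration, and treating the validation loader as keeping its last partial batch is also an assumption.

```python
import math

def resolve_defaults(bs=64, val_bs=None, shuffle=True, drop_last=None):
    """Hypothetical helper mirroring the documented fallbacks, not fastai code."""
    val_bs = bs if val_bs is None else val_bs                 # val_bs defaults to bs
    drop_last = shuffle if drop_last is None else drop_last   # drop_last defaults to shuffle
    return val_bs, drop_last

def n_batches(n_items, bs, drop_last):
    """Batches produced for n_items rows: the last partial batch is either
    dropped (floor division) or kept (ceiling)."""
    return n_items // bs if drop_last else math.ceil(n_items / bs)

val_bs, drop_last = resolve_defaults(bs=64)

# Assumed split: 800 training rows and 200 validation rows.
train_batches = n_batches(800, 64, drop_last)   # 800 = 12 * 64 + 32, remainder dropped
valid_batches = n_batches(200, val_bs, False)   # assumption: validation keeps the partial batch
print(train_batches, valid_batches)
```

With shuffling enabled, the 32 leftover training rows are skipped in a given epoch, which keeps every training batch at the full size of 64.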
Let's look at an example using the adult dataset:
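from fastai.tabular.all import *  # provides untar_data, URLs, pd and the tabular classes

# Download (if needed) and locate the sample of the adult dataset
path = untar_data(URLs.ADULT_SAMPLE)
# skipinitialspace=True strips the blank that follows each comma in the file
df = pd.read_csv(path/'adult.csv', skipinitialspace=True)
df.head()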
|  | age | workclass | fnlwgt | education | education-num | marital-status | occupation | relationship | race | sex | capital-gain | capital-loss | hours-per-week | native-country | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 49 | Private | 101320 | Assoc-acdm | 12.0 | Married-civ-spouse | NaN | Wife | White | Female | 0 | 1902 | 40 | United-States | >=50k |
| 1 | 44 | Private | 236746 | Masters | 14.0 | Divorced | Exec-managerial | Not-in-family | White | Male | 10520 | 0 | 45 | United-States | >=50k |
| 2 | 38 | Private | 96185 | HS-grad | NaN | Divorced | NaN | Unmarried | Black | Female | 0 | 0 | 32 | United-States | <50k |
| 3 | 38 | Self-emp-inc | 112847 | Prof-school | 15.0 | Married-civ-spouse | Prof-specialty | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | United-States | >=50k |
| 4 | 42 | Self-emp-not-inc | 82297 | 7th-8th | NaN | Married-civ-spouse | Other-service | Wife | Black | Female | 0 | 0 | 50 | United-States | <50k |
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
dls.show_batch()
|  | workclass | education | marital-status | occupation | relationship | race | education-num_na | age | fnlwgt | education-num | salary |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Private | HS-grad | Married-civ-spouse | Adm-clerical | Husband | White | False | 24.0 | 121312.998272 | 9.0 | <50k |
| 1 | Private | HS-grad | Never-married | Other-service | Not-in-family | White | False | 19.0 | 198320.000325 | 9.0 | <50k |
| 2 | Private | Bachelors | Married-civ-spouse | Sales | Husband | White | False | 66.0 | 169803.999308 | 13.0 | >=50k |
| 3 | Private | HS-grad | Divorced | Adm-clerical | Unmarried | White | False | 40.0 | 799280.980929 | 9.0 | <50k |
| 4 | Local-gov | 10th | Never-married | Other-service | Own-child | White | False | 18.0 | 55658.003629 | 6.0 | <50k |
| 5 | Private | HS-grad | Never-married | Handlers-cleaners | Other-relative | White | False | 30.0 | 375827.003847 | 9.0 | <50k |
| 6 | Private | Some-college | Never-married | Handlers-cleaners | Own-child | White | False | 20.0 | 173723.999335 | 10.0 | <50k |
| 7 | ? | Some-college | Never-married | ? | Own-child | White | False | 21.0 | 107800.997986 | 10.0 | <50k |
| 8 | Private | HS-grad | Never-married | Handlers-cleaners | Own-child | White | False | 19.0 | 263338.000072 | 9.0 | <50k |
| 9 | Private | Some-college | Married-civ-spouse | Tech-support | Husband | White | False | 35.0 | 194590.999986 | 10.0 | <50k |
TabularDataLoaders.from_csv
TabularDataLoaders.from_csv (csv:str|Path|io.BufferedReader,
skipinitialspace:bool=True,
path:str|Path='.', procs:list=None,
cat_names:list=None, cont_names:list=None,
y_names:list=None,
y_block:TransformBlock=None,
valid_idx:list=None, bs:int=64,
shuffle_train:bool=None, shuffle:bool=True,
val_shuffle:bool=False, n:int=None,
device:torch.device=None,
drop_last:bool=None, val_bs:int=None)
Create TabularDataLoaders from the csv file in path using procs
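| Parameter | Type | Default | Details |
|---|---|---|---|
| csv | str \| Path \| io.BufferedReader |  | A csv file of training data |
| skipinitialspace | bool | True | Skip spaces after the delimiter |
| path | str \| Path | . | Location of df, defaults to the current working directory |
| procs | list | None | List of TabularProcs |
| cat_names | list | None | Column names pertaining to categorical variables |
| cont_names | list | None | Column names pertaining to continuous variables |
| y_names | list | None | Names of the dependent variables |
| y_block | TransformBlock | None | TransformBlock to use for the target(s) |
| valid_idx | list | None | List of indices to use for the validation set, defaults to a random split |
| bs | int | 64 | Batch size |
| shuffle_train | bool | None | (Deprecated, use shuffle) Shuffle the training DataLoader |
| shuffle | bool | True | Shuffle the training DataLoader |
| val_shuffle | bool | False | Shuffle the validation DataLoader |
| n | int | None | Size of the Datasets used to create the DataLoader |
| device | device | None | Device to put DataLoaders on |
| drop_last | bool | None | Drop the last incomplete batch, defaults to shuffle |
| val_bs | int | None | Validation batch size, defaults to bs |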
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                  y_names="salary", valid_idx=list(range(800,1000)), bs=64)
TabularDataLoaders.test_dl
TabularDataLoaders.test_dl (test_items, rm_type_tfms=None,
process:bool=True, inplace:bool=False, bs=16,
shuffle=False, after_batch=None,
num_workers=0, verbose:bool=False,
do_setup:bool=True, pin_memory=False,
timeout=0, batch_size=None, drop_last=False,
indexed=None, n=None, device=None,
persistent_workers=False,
pin_memory_device='', wif=None,
before_iter=None, after_item=None,
before_batch=None, after_iter=None,
create_batches=None, create_item=None,
create_batch=None, retain=None,
get_idxs=None, sample=None, shuffle_fn=None,
do_batch=None)
Create a test TabDataLoader from test_items using the validation procs
| Parameter | Type | Default | Details |
|---|---|---|---|
| test_items |  |  | Items used to create the new test TabDataLoader, formatted the same as the training data |
| rm_type_tfms | NoneType | None | Number of Transforms to remove from procs |
| process | bool | True | Apply the validation TabularProcs to test_items immediately |
| inplace | bool | False | If False, keep a separate copy of the original test_items in memory |
| bs | int | 64 | Size of batch |
| shuffle | bool | False | Whether to shuffle the data |
| after_batch | NoneType | None |  |
| num_workers | int | None | Number of CPU cores to use in parallel (default: all available, up to 16) |
| verbose | bool | False | Whether to print verbose logs |
| do_setup | bool | True | Whether to run setup() for the batch transform(s) |
| pin_memory | bool | False |  |
| timeout | int | 0 |  |
| batch_size | NoneType | None |  |
| drop_last | bool | False |  |
| indexed | NoneType | None |  |
| n | NoneType | None |  |
| device | NoneType | None |  |
| persistent_workers | bool | False |  |
| pin_memory_device | str |  |  |
| wif | NoneType | None |  |
| before_iter | NoneType | None |  |
| after_item | NoneType | None |  |
| before_batch | NoneType | None |  |
| after_iter | NoneType | None |  |
| create_batches | NoneType | None |  |
| create_item | NoneType | None |  |
| create_batch | NoneType | None |  |
| retain | NoneType | None |  |
| get_idxs | NoneType | None |  |
| sample | NoneType | None |  |
| shuffle_fn | NoneType | None |  |
| do_batch | NoneType | None |  |
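External structured data files can contain unexpected spaces, for example after a comma. We can see this in the first row of adult.csv: "49, Private,101320, ...". Such values usually need trimming. Pandas has a convenient parameter for this, skipinitialspace, which TabularDataLoaders.from_csv() exposes. Without it, a category label used later for inference, such as workclass: Private, would be wrongly encoded as 0 or "#na#" if the training label had been read as " Private". Let's test this feature.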
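from fastcore.test import test_ne

# A single inference row matching the first row of adult.csv
test_data = {
    'age': [49],
    'workclass': ['Private'],
    'fnlwgt': [101320],
    'education': ['Assoc-acdm'],
    'education-num': [12.0],
    'marital-status': ['Married-civ-spouse'],
    'occupation': [''],
    'relationship': ['Wife'],
    'race': ['White'],
}
test_df = pd.DataFrame(test_data)  # renamed from `input` to avoid shadowing the built-in
tdl = dls.test_dl(test_df)
# 'Private' must map to a known category code rather than 0 ("#na#")
test_ne(0, tdl.dataset.iloc[0]['workclass'])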
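The whitespace problem described above can be reproduced without fastai or the dataset. The sketch below is a simplified illustration under stated assumptions: Python's standard csv module (which has an option also called skipinitialspace) stands in for pandas, and a plain dictionary with "#na#" at index 0 stands in for the category vocabulary. The helpers read_workclass, build_vocab and encode are hypothetical and are not fastai's implementation.

```python
import csv
import io

# One data row shaped like the first row of adult.csv, with a space after the comma.
raw = "age,workclass,fnlwgt\n49, Private,101320\n"

def read_workclass(text, skipinitialspace):
    """Parse the CSV text and return the raw workclass labels."""
    reader = csv.DictReader(io.StringIO(text), skipinitialspace=skipinitialspace)
    return [row["workclass"] for row in reader]

def build_vocab(labels):
    """Hypothetical vocabulary: index 0 is reserved for the '#na#' placeholder."""
    return {label: i for i, label in enumerate(["#na#"] + sorted(set(labels)))}

def encode(vocab, label):
    """Unknown labels fall back to index 0, i.e. '#na#'."""
    return vocab.get(label, 0)

untrimmed = build_vocab(read_workclass(raw, skipinitialspace=False))
trimmed = build_vocab(read_workclass(raw, skipinitialspace=True))

# Without trimming, the training label is ' Private', so 'Private' is unknown at inference.
print(encode(untrimmed, "Private"))
# With trimming, 'Private' matches the training label and gets a real category code.
print(encode(trimmed, "Private"))
```

The first lookup falls back to the "#na#" slot because only " Private" (with a leading space) was seen during training, which is the failure mode that the test_ne check guards against.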