表格数据

用于在表格应用中将数据转换为 DataLoaders 以及更高级别类 TabularDataLoaders 的辅助函数

准备数据用于模型训练的主要类是 TabularDataLoaders 及其工厂方法。请查阅表格数据教程以获取使用示例。


源代码

TabularDataLoaders

 TabularDataLoaders (*loaders, path:str|pathlib.Path='.', device=None)

多个 DataLoader 的基本封装,带有用于表格数据的工厂方法

类型 默认值 详细信息
loaders VAR_POSITIONAL 要封装的 DataLoader 对象
path str | pathlib.Path . 导出对象的存储路径
device NoneType None 放置 DataLoaders 的设备

不应直接使用此类别,应优先使用其中一种工厂方法。所有这些工厂方法都接受以下参数

  • cat_names: 分类变量的名称
  • cont_names: 连续变量的名称
  • y_names: 因变量的名称
  • y_block: 用于目标的 TransformBlock
  • valid_idx: 用于验证集的索引(否则默认为随机分割)
  • bs: 批处理大小
  • val_bs: 用于验证 DataLoader 的批处理大小(默认为 bs
  • shuffle_train: 是否打乱训练 DataLoader
  • n: 覆盖数据集中元素的数量
  • device: 要使用的 PyTorch 设备(默认为 default_device()

源代码

TabularDataLoaders.from_df

 TabularDataLoaders.from_df (df:pd.DataFrame, path:str|Path='.',
                             procs:list=None, cat_names:list=None,
                             cont_names:list=None, y_names:list=None,
                             y_block:TransformBlock=None,
                             valid_idx:list=None, bs:int=64,
                             shuffle_train:bool=None, shuffle:bool=True,
                             val_shuffle:bool=False, n:int=None,
                             device:torch.device=None,
                             drop_last:bool=None, val_bs:int=None)

使用 procspath 中的 df 创建 TabularDataLoaders

类型 默认值 详细信息
df pd.DataFrame
path str | Path . df 的位置,默认为当前工作目录
procs list None TabularProc 列表
cat_names list None 与分类变量相关的列名
cont_names list None 与连续变量相关的列名
y_names list None 因变量的名称
y_block TransformBlock None 用于目标(s)的 TransformBlock
valid_idx list None 用于验证集的索引列表,默认为随机分割
bs int 64 批处理大小
shuffle_train bool None (已弃用,请使用 shuffle)打乱训练 DataLoader
shuffle bool True 打乱训练 DataLoader
val_shuffle bool False 打乱验证 DataLoader
n int None 用于创建 DataLoaderDatasets 大小
device device None 放置 DataLoaders 的设备
drop_last bool None 丢弃最后一个不完整的批次,默认为 shuffle
val_bs int None 验证批处理大小,默认为 bs

让我们看看使用 adult 数据集的示例

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv', skipinitialspace=True)
df.head()
age workclass fnlwgt education education-num marital-status occupation relationship race sex capital-gain capital-loss hours-per-week native-country salary
0 49 Private 101320 Assoc-acdm 12.0 Married-civ-spouse NaN Wife White Female 0 1902 40 United-States >=50k
1 44 Private 236746 Masters 14.0 Divorced Exec-managerial Not-in-family White Male 10520 0 45 United-States >=50k
2 38 Private 96185 HS-grad NaN Divorced NaN Unmarried Black Female 0 0 32 United-States <50k
3 38 Self-emp-inc 112847 Prof-school 15.0 Married-civ-spouse Prof-specialty Husband Asian-Pac-Islander Male 0 0 40 United-States >=50k
4 42 Self-emp-not-inc 82297 7th-8th NaN Married-civ-spouse Other-service Wife Black Female 0 0 50 United-States <50k
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names, 
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
dls.show_batch()
workclass education marital-status occupation relationship race education-num_na age fnlwgt education-num salary
0 Private HS-grad Married-civ-spouse Adm-clerical Husband White False 24.0 121312.998272 9.0 <50k
1 Private HS-grad Never-married Other-service Not-in-family White False 19.0 198320.000325 9.0 <50k
2 Private Bachelors Married-civ-spouse Sales Husband White False 66.0 169803.999308 13.0 >=50k
3 Private HS-grad Divorced Adm-clerical Unmarried White False 40.0 799280.980929 9.0 <50k
4 Local-gov 10th Never-married Other-service Own-child White False 18.0 55658.003629 6.0 <50k
5 Private HS-grad Never-married Handlers-cleaners Other-relative White False 30.0 375827.003847 9.0 <50k
6 Private Some-college Never-married Handlers-cleaners Own-child White False 20.0 173723.999335 10.0 <50k
7 ? Some-college Never-married ? Own-child White False 21.0 107800.997986 10.0 <50k
8 Private HS-grad Never-married Handlers-cleaners Own-child White False 19.0 263338.000072 9.0 <50k
9 Private Some-college Married-civ-spouse Tech-support Husband White False 35.0 194590.999986 10.0 <50k

源代码

TabularDataLoaders.from_csv

 TabularDataLoaders.from_csv (csv:str|Path|io.BufferedReader,
                              skipinitialspace:bool=True,
                              path:str|Path='.', procs:list=None,
                              cat_names:list=None, cont_names:list=None,
                              y_names:list=None,
                              y_block:TransformBlock=None,
                              valid_idx:list=None, bs:int=64,
                              shuffle_train:bool=None, shuffle:bool=True,
                              val_shuffle:bool=False, n:int=None,
                              device:torch.device=None,
                              drop_last:bool=None, val_bs:int=None)

使用 procspath 中的 csv 文件创建 TabularDataLoaders

类型 默认值 详细信息
csv str | Path | io.BufferedReader 训练数据的 csv 文件
skipinitialspace bool True 跳过分隔符后的空格
path str | Path . df 的位置,默认为当前工作目录
procs list None TabularProc 列表
cat_names list None 与分类变量相关的列名
cont_names list None 与连续变量相关的列名
y_names list None 因变量的名称
y_block TransformBlock None 用于目标(s)的 TransformBlock
valid_idx list None 用于验证集的索引列表,默认为随机分割
bs int 64 批处理大小
shuffle_train bool None (已弃用,请使用 shuffle)打乱训练 DataLoader
shuffle bool True 打乱训练 DataLoader
val_shuffle bool False 打乱验证 DataLoader
n int None 用于创建 DataLoaderDatasets 大小
device device None 放置 DataLoaders 的设备
drop_last bool None 丢弃最后一个不完整的批次,默认为 shuffle
val_bs int None 验证批处理大小,默认为 bs
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names, 
                                  y_names="salary", valid_idx=list(range(800,1000)), bs=64)

源代码

TabularDataLoaders.test_dl

 TabularDataLoaders.test_dl (test_items, rm_type_tfms=None,
                             process:bool=True, inplace:bool=False, bs=16,
                             shuffle=False, after_batch=None,
                             num_workers=0, verbose:bool=False,
                             do_setup:bool=True, pin_memory=False,
                             timeout=0, batch_size=None, drop_last=False,
                             indexed=None, n=None, device=None,
                             persistent_workers=False,
                             pin_memory_device='', wif=None,
                             before_iter=None, after_item=None,
                             before_batch=None, after_iter=None,
                             create_batches=None, create_item=None,
                             create_batch=None, retain=None,
                             get_idxs=None, sample=None, shuffle_fn=None,
                             do_batch=None)

使用验证 procstest_items 创建测试 TabDataLoader

类型 默认值 详细信息
test_items 用于创建新的测试 TabDataLoader 的条目,格式与训练数据相同
rm_type_tfms NoneType None procs 中移除的 Transform 数量
process bool True 立即将验证 TabularProc 应用于 test_items
inplace bool False 如果为 False,则在内存中保留原始 test_items 的单独副本
bs int 64 批次大小
shuffle bool False 是否打乱数据
after_batch NoneType None
num_workers int None 并行使用的 CPU 核数(默认值:最多 16 个可用核)
verbose bool False 是否打印详细日志
do_setup bool True 是否对批次变换运行 setup()
pin_memory bool False
timeout int 0
batch_size NoneType None
drop_last bool False
indexed NoneType None
n NoneType None
device NoneType None
persistent_workers bool False
pin_memory_device str
wif NoneType None
before_iter NoneType None
after_item NoneType None
before_batch NoneType None
after_iter NoneType None
create_batches NoneType None
create_item NoneType None
create_batch NoneType None
retain NoneType None
get_idxs NoneType None
sample NoneType None
shuffle_fn NoneType None
do_batch NoneType None

外部结构化数据文件可能包含意外的空格,例如在逗号之后。我们可以在 adult.csv 的第一行看到这种情况:"49, Private,101320, ..."。通常需要进行修剪。Pandas 有一个方便的参数 skipinitialspace,它在 TabularDataLoaders.from_csv() 中得以暴露。否则,以后用于推断的类别标签,例如 workclass:Private,如果训练标签被读取为 " Private",则会被错误地归类为 0"#na#"。让我们测试这个功能。

test_data = {
    'age': [49], 
    'workclass': ['Private'], 
    'fnlwgt': [101320],
    'education': ['Assoc-acdm'], 
    'education-num': [12.0],
    'marital-status': ['Married-civ-spouse'], 
    'occupation': [''],
    'relationship': ['Wife'],
    'race': ['White'],
}
input = pd.DataFrame(test_data)
tdl = dls.test_dl(input)

test_ne(0, tdl.dataset.iloc[0]['workclass'])