DataLoader 相关
DataLoader 详细参数(PyTorch 2.7):
1 | def __init__( |
部分参数讲解:(其余属于熟知参数)
sampler:样本索引的生成方法(同 Shuffle 互斥)
batch_sampler:batch 索引列表的生成方法(同 batch_size、shuffle、sampler 互斥)
worker_init_fn:每个子进程启动时执行的初始化函数
multiprocessing_context:指定多进程启动方式(fork、spawn、forkserver)
collate_fn:定义将一个 batch 的样本列表拼接成 tensor 的方法
generator:控制随机数生成的方法
pin_memory_device:设置固定内存的设备
in_order:控制多 worker 时输出 Batch 的顺序
补注:
1.参数上面,worker_init_fn、collate_fn 属于比较常用的方案,可以总结用法
2.pin_memory_device、in_order 需要做相关实验
Dataset 基础框架:
1 | class MyDataset(Dataset): |
基础配合流程:(内部实现流程远不止那么简单,这里只阐述基本操作对流程进行简化)
单进程:
主进程根据 Dataset 返回的 len 划定随机数的取值范围,从中取索引
主进程取索引后通过 getitem 函数获得该索引对应的数据
主进程调用 collate 再将数据转为 Tensor
多进程:
主进程根据 Dataset 返回的 len 划定随机数的取值范围,从中取索引并分发给各个 worker
每个 worker 调用 getitem 函数获得该索引对应的数据
每个 worker 调用 collate 将数据转为 Tensor
主进程接收 Tensor 并返回整个 Batch
注:从上述流程中可以得到 len 必须是数据的全长,否则 DataLoader 无法识别数据集全长
注:多进程模式下,主进程仅分发索引、接收 Tensor,其余均在各 worker 内实现,包括 collate
注:因而,getitem、collate 均可以用于数据预处理,都可以利用并行加速
相关实验结论:
多进程基本原理、多进程运行原理、多进程模式、多进程并行、多进程数量、内存相关事项、持久化进程、DataLoaderX、不同文件格式的读取与加速策略、锁页内存与异步拷贝、数据预处理及其余优化方案