DataLoaderX(prefetcher_generator、BackgroundGenerator):
该方案的使用需要额外下载库:
1
| pip install prefetcher_generator
|
具体使用方法(案例):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41
| import numpy as np import time from prefetch_generator import BackgroundGenerator, background
# 模拟耗时 batch generator def iterate_minibatches(n_batches=10, batch_size=10): for _ in range(n_batches): time.sleep(0.1) # 模拟数据生成耗时 X = np.random.randn(batch_size, 20) y = np.random.randint(0, 2, batch_size) yield X, y
# ---------------- 普通 generator ---------------- t0 = time.time() for X, y in iterate_minibatches(10): time.sleep(0.1) # 模拟训练 print('!', end=' ') print('\n普通 generator耗时:', time.time()-t0)
# ---------------- BackgroundGenerator ---------------- t1 = time.time() for X, y in BackgroundGenerator(iterate_minibatches(10), max_prefetch=3): time.sleep(0.1) # 模拟训练 print('!', end=' ') print('\nBackgroundGenerator耗时:', time.time()-t1)
# ---------------- @background 装饰器 ---------------- @background(max_prefetch=3) def bg_iterate_minibatches(n_batches=10, batch_size=10): for _ in range(n_batches): time.sleep(0.1) # 模拟数据生成耗时 X = np.random.randn(batch_size, 20) y = np.random.randint(0, 2, batch_size) yield X, y
t2 = time.time() for X, y in bg_iterate_minibatches(10): time.sleep(0.1) # 模拟训练 print('!', end=' ') print('\n@background耗时:', time.time()-t2)
|
同时,也可以用于 DataLoader 的包裹:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| import torch from torch.utils.data import Dataset, DataLoader from prefetch_generator import BackgroundGenerator import time import numpy as np
# ---------------- 简单 Dataset ---------------- class MyDataset(Dataset): def __init__(self, size=20): self.data = list(range(size)) def __len__(self): return len(self.data) def __getitem__(self, idx): time.sleep(0.05) # 模拟数据读取或 CPU transform return torch.tensor(self.data[idx])
# ---------------- DataLoaderX ---------------- class DataLoaderX(DataLoader): """ 扩展 DataLoader,后台线程异步预取 batch """ def __iter__(self): return BackgroundGenerator(super().__iter__(), max_prefetch=2)
# ---------------- 测试普通 DataLoader ---------------- dataset = MyDataset(20) loader = DataLoader(dataset, batch_size=4, num_workers=2)
t0 = time.time() for batch in loader: time.sleep(0.1) # 模拟训练 print(batch, end=' ') print('\n普通 DataLoader耗时:', time.time()-t0)
# ---------------- 测试 DataLoaderX ---------------- prefetch_loader = DataLoaderX(dataset, batch_size=4, num_workers=2)
t1 = time.time() for batch in prefetch_loader: time.sleep(0.1) # 模拟训练 print(batch, end=' ') print('\nDataLoaderX耗时:', time.time()-t1)
|
下面即为改进后的运行流程图:
1 2 3 4 5 6
| 普通 generator: [生成 batch0][训练 batch0][生成 batch1][训练 batch1]...
BackgroundGenerator: 后台线程: [生成 batch0] [生成 batch1] [生成 batch2] ... 主线程: [训练 batch0] [训练 batch1] [训练 batch2] ...
|
实际上将生成器、训练流程一分为二,异步进行
理论上,能够将原先数据加载后训练的流程变为边加载数据边训练的流程,大幅度提高效率
场景上,理论上最适合取数据时间较短但训练较慢的场景,可以达到完全异步效果
针对取数据时间较长、训练较快场景,同样可以降低启动开销、同步开销,令数据进程一直忙碌
可以认为是比较有效的策略,这个策略同样适用于 DataLoader

补注:
1.上述两个代码均可以运行,可以基于此结果进一步实验效果(验证上述理论)
其官方代码演示案例:
https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb
2.同时,可以针对数据集数量、预处理时间、主训练时间进行消融,从而得到最佳实践