How multiprocess loading works

# prime the prefetch loop (initialization)
for _ in range(self._prefetch_factor * self._num_workers):
    self._try_put_index()

At initialization the index queues are primed, so that prefetch_factor * num_workers data indices are already enqueued before the first batch is consumed.
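The priming step can be sketched with a plain queue (a minimal sketch: the numbers and the `next_index` iterator standing in for the sampler are made up, not DataLoader internals):

```python
from queue import Queue

# Hypothetical toy values; the real ones come from the DataLoader arguments.
prefetch_factor = 2
num_workers = 4

index_queue = Queue()
next_index = iter(range(100))  # stand-in for the sampler

# Prime the loop: enqueue prefetch_factor * num_workers indices up front,
# so every worker has work waiting before training starts.
for _ in range(prefetch_factor * num_workers):
    index_queue.put(next(next_index))

print(index_queue.qsize())  # 8 indices queued before the first batch is consumed
```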

# process a completed task's data
def _process_data(self, data, worker_idx):
    self._workers_num_tasks[worker_idx] -= 1
    self._try_put_index()  # enqueue the next index
    if isinstance(data, ExceptionWrapper):
        data.reraise()
    return data

Data is retrieved by index; every time a batch comes back, this function also calls _try_put_index to push one more batch's index into a queue.
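The effect is a fixed window of outstanding work: each consumed batch immediately refills one index, so the number of dispatched-but-unconsumed indices stays pinned at prefetch_factor * num_workers. A toy bookkeeping sketch (all names here are made up, not DataLoader internals):

```python
# Toy bookkeeping: one batch consumed -> one index refilled,
# so the prefetch window size never changes.
prefetch_factor, num_workers = 2, 2
window = prefetch_factor * num_workers

pending = list(range(window))        # indices already dispatched (primed)
next_idx = window                    # next sampler index to dispatch
consumed = []

for _ in range(6):                   # consume six batches
    consumed.append(pending.pop(0))  # like _process_data: a batch comes back
    pending.append(next_idx)         # like _try_put_index: refill one index
    next_idx += 1
    assert len(pending) == window    # invariant: window stays the same size

print(consumed, pending)
```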

# enqueue an index
def _try_put_index(self):
    max_tasks = self._prefetch_factor * self._num_workers
    assert self._tasks_outstanding < max_tasks

    try:
        index = self._next_index()
    except StopIteration:
        return
    for _ in range(self._num_workers):  # find the next active worker, if any
        worker_queue_idx = next(self._worker_queue_idx_cycle)
        if self._workers_status[worker_queue_idx]:
            if self._in_order:
                break
            elif self._workers_num_tasks[worker_queue_idx] < max_tasks // sum(
                self._workers_status
            ):
                break
    else:
        # not found (i.e., didn't break)
        return

    self._index_queues[worker_queue_idx].put((self._send_idx, index))  # type: ignore[possibly-undefined]
    self._task_info[self._send_idx] = (worker_queue_idx,)
    self._workers_num_tasks[worker_queue_idx] += 1
    self._tasks_outstanding += 1
    self._send_idx += 1

This function fetches the next index and enqueues it for a worker (each worker has its own index queue).
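The round-robin worker selection can be sketched with itertools.cycle, mirroring how _worker_queue_idx_cycle skips shut-down workers (a toy version; the `workers_status` state and helper name are made up for illustration):

```python
from itertools import cycle

num_workers = 3
workers_status = [True, False, True]  # hypothetical: worker 1 has shut down
worker_queue_idx_cycle = cycle(range(num_workers))

def next_active_worker():
    """Return the next alive worker id, or None if none are alive."""
    for _ in range(num_workers):
        idx = next(worker_queue_idx_cycle)
        if workers_status[idx]:
            return idx
    return None

# Indices are dealt out only to alive workers, skipping worker 1.
print([next_active_worker() for _ in range(4)])  # [0, 2, 0, 2]
```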

while watchdog.is_alive():
    try:
        # wait for a task (blocks while the queue is empty; returns immediately otherwise)
        r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
    except queue.Empty:
        continue

    # ... handle control signals ...

    idx, index = r                 # unpack the task
    data = fetcher.fetch(index)    # fetch the data right away
    data_queue.put((idx, data))    # push it onto the output queue
    del data, idx, index, r        # release references promptly

This is the worker's main loop: whenever it receives an index, it fetches the corresponding data and puts it on the output queue.

In principle this already achieves asynchronous data loading: new batches are fetched while earlier ones are being consumed.

The code implements this with queues, a pattern that works well for asynchronous concurrency in general.

The effect looks like this:


Main process:     train batch0 ────┐ train batch1 ────┐ train batch2 ────┐ ...
Worker processes: prefetch batch1──  prefetch batch2──  prefetch batch3── ...