laekov 遇到的问题是希望使用 mpi 进行一些 cpu 的通信, 以及用 nccl 来进行一些 gpu 间的通信. // 没错, 说的就是让 FastDecode 支持模型并行.
众所周知, pytorch 和 mpi 的关系不太好的亚子. 然而 mpi 其实是最方便, 也是最适合大规模使用的. 虽然新版的 pytorch 试图引入 ucc backend 来支持 cpu 和 gpu 的混合通信, 但在较老的版本或一些特定环境里还无法顺畅地使用 ucc.
事实上 pytorch 支持创建不同 backend 的不同 ProcessGroup. 然而如果写如下一段代码, 就会发现并跑不起来.
dist.init_process_group(backend='mpi')
...
g = dist.new_group(list(range(world_size)), backend='nccl')
x = torch.ones(16, device='cuda') * (rank + 1)
xs = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(xs, x, group=g)
甚至会获得一些在 ucx 里的错误栈, 再往上找会发现报错在 c10d::PrefixStore::set
这个函数里.
可以理解的是对于非 mpi 的后端, pytorch 使用了 Store
(基于 tcp 或共享文件系统) 来进行建链前的通信.
那么对于 mpi backend, pytorch 到底使用了什么 store, 成为了这个问题的关键.
阅读了 torch/distributed/distributed_c10d.py
里相应的代码之后, laekov 惊奇地发现, 对于 mpi backend, store 竟然是 None
.
这就不难解释上面这个错误是怎么来的了.
那有没有什么办法补救呢?
pytorch 会把 store 放到 distributed_c10d 的一个全局 map 里, 所以可以自建一个 tcp store, 然后覆盖到这个 map 里.
所以下面这段代码就是 laekov 发明的核心科技了:
rdv = torch.distributed.rendezvous("env://", rank, world_size)
store, _1, _2 = next(rdv)
torch.distributed.distributed_c10d._pg_map[torch.distributed.distributed_c10d._get_default_group()] = ('mpi', store)
果然, 把这段代码插到 new_group
前面, 事情就变得科学了起来. 耶!