laekov 遇到的问题是希望使用 mpi 进行一些 cpu 的通信, 以及用 nccl 来进行一些 gpu 间的通信. // 没错, 说的就是让 FastDecode 支持模型并行.

众所周知, pytorch 和 mpi 的关系不太好的亚子. 然而 mpi 其实是最方便, 也是最适合大规模使用的. 虽然新版的 pytorch 试图引入 ucc backend 来支持 cpu 和 gpu 的混合通信, 但在较老的版本或一些特定环境里还无法顺畅地使用 ucc.

事实上 pytorch 支持创建不同 backend 的不同 ProcessGroup. 然而如果写如下一段代码, 就会发现并跑不起来.

    dist.init_process_group(backend='mpi')
    ...
    g = dist.new_group(list(range(world_size)), backend='nccl')
    x = torch.ones(16, device='cuda') * (rank + 1)
    xs = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(xs, x, group=g)

甚至会获得一些在 ucx 里的错误栈, 再往上找会发现报错在 c10d::PrefixStore::set 这个函数里. 可以理解的是对于非 mpi 的后端, pytorch 使用了 Store (基于 tcp 或共享文件系统) 来进行建链前的通信. 那么对于 mpi backend, pytorch 到底使用了什么 store, 成为了这个问题的关键.

阅读了 torch/distributed/distributed_c10d.py 里相应的代码之后, laekov 惊奇地发现, 对于 mpi backend, store 竟然是 None. 这就不难解释上面这个错误是怎么来的了.

那有没有什么办法补救呢?

pytorch 会把 store 放到 distributed_c10d 的一个全局 map 里, 所以可以自建一个 tcp store, 然后覆盖到这个 map 里.

所以下面这段代码就是 laekov 发明的核心科技了:

    rdv = torch.distributed.rendezvous("env://", rank, world_size)
    store, _1, _2 = next(rdv)
    torch.distributed.distributed_c10d._pg_map[torch.distributed.distributed_c10d._get_default_group()] = ('mpi', store)

果然, 把这段代码插到 new_group 前面, 事情就变得科学了起来. 耶!