batch_isend_irecv supports batching multiple groups of P2P operations in a single call, but under the hood they are still issued as individual point-to-point transfers. Below are several NCCL test scenarios.
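Before the stress tests, here is a minimal sketch of the basic batch_isend_irecv usage that all three scenarios build on. This baseline is not part of the original tests; it assumes a torchrun-style launch so that LOCAL_RANK is set in the environment. Each rank posts one isend to the next rank and one irecv from the previous rank, then waits on the returned requests.

import os
import torch
import torch.distributed as dist


def main():
    # Minimal ring exchange: rank r sends to (r + 1) % world_size and
    # receives from (r - 1) % world_size, batched into one NCCL group.
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    send_tensor = torch.full((4,), float(rank), device='cuda')
    recv_tensor = torch.empty(4, device='cuda')

    ops = [
        dist.P2POp(dist.isend, send_tensor, (rank + 1) % world_size),
        dist.P2POp(dist.irecv, recv_tensor, (rank - 1 + world_size) % world_size),
    ]
    reqs = dist.batch_isend_irecv(ops)
    # One request is returned per op; wait on all of them before reading recv_tensor.
    for req in reqs:
        req.wait()
    print(f"rank {rank} received", recv_tensor)


if __name__ == '__main__':
    main()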

  1. Configuring multiple P2P pairs
import os
import time

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend='nccl')
    if not dist.is_initialized():
        return

    torch.manual_seed(1)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    print("local rank is:", local_rank)
    torch.cuda.set_device(local_rank)

    # 256 send/recv tensor pairs: each rank sends to the next rank and
    # receives from the previous rank in a ring.
    send_tensor = []
    recv_tensor = []
    ops_list = []
    for i in range(256):
        send_tensor.append(torch.ones((1280, 1280), dtype=torch.float32, device='cuda') * rank + i * 0.001)
        recv_tensor.append(torch.randn((1280, 1280), dtype=torch.float32, device='cuda'))

    # Interleave the operations: send_0, recv_0, send_1, recv_1, ...
    for i in range(256):
        send_op = dist.P2POp(dist.isend, send_tensor[i], (rank + 1) % world_size)
        recv_op = dist.P2POp(dist.irecv, recv_tensor[i], (rank - 1 + world_size) % world_size)
        ops_list.append(send_op)
        ops_list.append(recv_op)

    reqs = dist.batch_isend_irecv(ops_list)
    # Wait on every request before reading the received tensors.
    for req in reqs:
        req.wait()
    dist.barrier()

    # Stagger the output by rank so the logs are easier to read.
    time.sleep(rank * 3)

    for i in range(256):
        print("recv tensor is:", i, recv_tensor[i].reshape(-1), rank)


if __name__ == '__main__':
    main()
  2. Configuring consecutive communication groups (all sends queued first, then all receives)
import os
import time

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend='nccl')
    if not dist.is_initialized():
        return

    torch.manual_seed(1)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    print("local rank is:", local_rank)
    torch.cuda.set_device(local_rank)

    # 256 send/recv tensor pairs, same ring pattern as scenario 1.
    send_tensor = []
    recv_tensor = []
    ops_list = []
    for i in range(256):
        send_tensor.append(torch.ones((1280, 1280), dtype=torch.float32, device='cuda') * rank + i * 0.001)
        recv_tensor.append(torch.randn((1280, 1280), dtype=torch.float32, device='cuda'))

    # Unlike scenario 1, queue all sends first, then all receives.
    for i in range(256):
        send_op = dist.P2POp(dist.isend, send_tensor[i], (rank + 1) % world_size)
        ops_list.append(send_op)
    for i in range(256):
        recv_op = dist.P2POp(dist.irecv, recv_tensor[i], (rank - 1 + world_size) % world_size)
        ops_list.append(recv_op)

    reqs = dist.batch_isend_irecv(ops_list)
    # Wait on every request before reading the received tensors.
    for req in reqs:
        req.wait()
    dist.barrier()

    # Stagger the output by rank so the logs are easier to read.
    time.sleep(rank * 3)

    for i in range(256):
        print("recv tensor is:", i, recv_tensor[i].reshape(-1), rank)


if __name__ == '__main__':
    main()
  3. Configuring multi-GPU, multi-group communication (every rank exchanges tensors with every rank)
import os
import time

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend='nccl')
    if not dist.is_initialized():
        return

    torch.manual_seed(1)
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    print("local rank is:", local_rank)
    torch.cuda.set_device(local_rank)

    # Every rank posts one recv from and one send to each peer rank (including
    # itself). The batch is built 32 times over to stress the grouped P2P path;
    # only the tensors from the last iteration are kept for the printout below.
    recv_tensors = [None for _ in range(world_size)]
    expected_tensors = [None for _ in range(world_size)]
    p2p_op_list = []
    for i in range(32):
        for src in range(world_size):
            # Tensor shapes depend on the sender's rank: this rank sends a
            # (rank + 1)^3 tensor filled with the peer id, and receives a
            # (src + 1)^3 tensor that should be filled with this rank's id.
            send_tensor = torch.empty(rank + 1, rank + 1, rank + 1, dtype=torch.float).fill_(src).cuda(local_rank)
            recv_tensors[src] = torch.empty(src + 1, src + 1, src + 1, dtype=torch.float).fill_(-1).cuda(local_rank)
            expected_tensors[src] = torch.empty(src + 1, src + 1, src + 1, dtype=torch.float).fill_(rank)
            recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
            p2p_op_list.append(recv_op)
            send_op = dist.P2POp(dist.isend, send_tensor, src)
            p2p_op_list.append(send_op)

    reqs = dist.batch_isend_irecv(p2p_op_list)
    for req in reqs:
        req.wait()

    # Stagger the output by rank so the logs are easier to read.
    time.sleep(rank * 3)

    for i in range(world_size):
        print("recv tensor is:", i, recv_tensors[i].reshape(-1), rank)
        print("expect tensor is:", i, expected_tensors[i].reshape(-1), rank)
        print("\n")


if __name__ == '__main__':
    main()
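The third script builds expected_tensors but leaves the comparison with recv_tensors to a visual check of the printout. A small helper along the lines of the sketch below (verify_recv is not part of the original script) could be called at the end of main() to make the check explicit:

def verify_recv(recv_tensors, expected_tensors, rank):
    # Compare each received tensor with its expected value; move the GPU
    # tensor to the CPU so both operands live on the same device.
    for src, (recv, expected) in enumerate(zip(recv_tensors, expected_tensors)):
        ok = torch.equal(recv.cpu(), expected)
        print(f"rank {rank}: recv from {src} {'matches' if ok else 'MISMATCHES'} the expected tensor")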
