intel / torch-ccl

oneCCL Bindings for Pytorch*
BSD 3-Clause "New" or "Revised" License

reduce_scatter_tensor raises ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY in multi-node usage #65

Open garrett361 opened 5 months ago


Cross-posting from this IPEX issue.

Repeated calls to torch.distributed.reduce_scatter_tensor eventually raise a ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY error in multi-node setups. Fully Sharded Data Parallel (FSDP), which calls reduce_scatter_tensor internally, shows the same behavior.
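
For readers less familiar with the collective, its SUM semantics can be modeled in plain Python. This is a hedged sketch of what `reduce_scatter_tensor(output, input, op=ReduceOp.SUM)` computes, not of how oneCCL implements it; `reduce_scatter_sum_reference` is a hypothetical name, and nested lists stand in for tensors.

```python
from typing import List


def reduce_scatter_sum_reference(inputs: List[List[float]]) -> List[List[float]]:
    """Pure-Python model of a SUM reduce-scatter (hypothetical helper).

    inputs[r] is the full-length input held by rank r. All inputs must have
    the same length, and that length must be divisible by the number of ranks.
    Returns outputs, where outputs[r] is the chunk that rank r receives.
    """
    world_size = len(inputs)
    dim = len(inputs[0])
    if any(len(t) != dim for t in inputs):
        raise ValueError("all ranks must contribute inputs of equal length")
    if dim % world_size != 0:
        raise ValueError("input length must be divisible by the world size")
    chunk = dim // world_size
    outputs = []
    for r in range(world_size):
        # Rank r receives the elementwise sum, over all ranks, of chunk r.
        lo, hi = r * chunk, (r + 1) * chunk
        outputs.append([sum(t[i] for t in inputs) for i in range(lo, hi)])
    return outputs


# Two ranks, each holding a length-4 input; each receives a length-2 chunk.
print(reduce_scatter_sum_reference([[1, 2, 3, 4], [10, 20, 30, 40]]))
# prints [[11, 22], [33, 44]]
```

The divisibility check mirrors why the reproduction script forces `dim` to be a multiple of the world size: each rank receives exactly `dim // world_size` elements.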

The script to reproduce is below. Steps:

  1. Create source and destination tensors on all ranks in a multi-node setup.
  2. Repeatedly call reduce_scatter_tensor and print memory readings at each step.
  3. Eventually, the above error is raised, without any corresponding jump in the memory readings.

```python
import argparse
import os

import intel_extension_for_pytorch as ipex  # noqa
import oneccl_bindings_for_pytorch  # noqa
import torch
import torch.distributed as dist


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dim",
        type=int,
        default=2**30,
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=100,
    )
    args = parser.parse_args()
    return args


def main(dim: int, dtype: str, max_steps: int) -> None:
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"xpu:{local_rank}")
    torch.xpu.set_device(device)

    # Force dim to be divisible by the world size
    new_dim = world_size * (dim // world_size)
    if new_dim != dim:
        if not rank:
            print(
                f"Adjusting original {dim=} to {new_dim} in order to be divisible by {world_size=}",
                flush=True,
            )
        dim = new_dim

    try:
        dist.init_process_group("ccl")

        t_in = torch.randn(dim, dtype=getattr(torch, dtype), device=device)
        t_out = torch.empty(dim // world_size, dtype=getattr(torch, dtype), device=device)

        for step in range(1, max_steps + 1):
            dist.reduce_scatter_tensor(t_out, t_in, op=dist.ReduceOp.SUM)
            torch.xpu.synchronize()
            peak_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.peak"] / 2**30
            current_mem_gib = torch.xpu.memory_stats()["allocated_bytes.all.current"] / 2**30
            print(f"[{rank=}]: {step=} memory {peak_mem_gib=}, {current_mem_gib=}", flush=True)

    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    args = get_args()
    main(**vars(args))
```
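
The flat reading in the logs can be checked by hand. This sketch assumes the run used the script's defaults (`--dim 2**30`, `--dtype bfloat16`, so 2 bytes per element) with the 24 ranks (0 through 23) that appear in the logs. It also assumes the XPU caching allocator rounds each allocation up to a multiple of 2 MiB; that rounding rule is an assumption, not something stated in this issue. `expected_allocated_gib` and `round_up` are hypothetical helpers.

```python
GIB = 2**30
# Assumption: allocation sizes are rounded up to a multiple of 2 MiB.
BLOCK = 2 * 2**20


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple`."""
    return -(-n // multiple) * multiple


def expected_allocated_gib(dim: int, world_size: int, itemsize: int = 2) -> float:
    # Same divisibility adjustment as the reproduction script.
    dim = world_size * (dim // world_size)
    in_bytes = dim * itemsize                   # t_in: full-length input
    out_bytes = (dim // world_size) * itemsize  # t_out: one rank's chunk
    # Under the 2 MiB rounding assumption, each tensor occupies whole blocks.
    total = round_up(in_bytes, BLOCK) + round_up(out_bytes, BLOCK)
    return total / GIB


print(expected_allocated_gib(2**30, 24))  # prints 2.083984375
```

Under these assumptions the adjustment turns `dim` into 1073741808, `t_in` occupies 1024 blocks of 2 MiB and `t_out` occupies 43 blocks, giving 2 + 43/512 = 2.083984375 GiB, which matches the logged value. If the assumptions hold, the tracked allocations are exactly `t_in` and `t_out`, consistent with the observation that the error is raised without any jump in the readings.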

Example logs:

[... snip ...]
[rank=14]: step=27 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=13]: step=27 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=17]: step=27 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=6]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=4]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=2]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=8]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=10]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=7]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=1]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=11]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=3]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=9]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=20]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=0]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=5]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=19]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=21]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=22]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=23]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=16]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=15]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=18]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=12]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=14]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=13]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=17]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=6]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=11]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=2]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=8]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=10]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=0]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=4]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=1]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=3]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=9]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=20]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=7]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=5]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=23]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=22]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=12]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=18]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=15]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=14]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=16]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=21]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=13]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=19]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=17]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
2024:05:29-19:16:18:(202165) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202162) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202164) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202173) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202167) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202166) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(149693) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202168) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
2024:05:29-19:16:18:(202163) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
/lus/gila/projects/Aurora_deployment/mk/decoders/alcf/set_torch_dist_env.sh: line 25: 200400 Aborted                 $@
x1921c5s2b0n0.hostmgmt2000.cm.americas.sgi.com: rank 6 exited with code 134
x1921c5s2b0n0.hostmgmt2000.cm.americas.sgi.com: rank 0 died from signal 15
2024:05:29-19:16:18:(149692) |CCL_ERROR| worker.cpp:338 ccl_worker_func: worker 0 caught internal exception: oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
terminate called after throwing an instance of 'ccl::v1::exception'
  what():  oneCCL: ze_call.cpp:28 do_call: EXCEPTION: ze error at zeCommandQueueExecuteCommandLists, code: ZE_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
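
Because the per-rank readings are interleaved, a small parser can confirm that every rank's `current_mem_gib` stays constant up to the last step it printed before the abort. This is a hedged sketch that assumes the exact print format used in the script above; `summarize` is a hypothetical helper, and lines that do not match (such as the `CCL_ERROR` lines) are skipped.

```python
import re
from typing import Dict, Iterable, Set, Tuple

# Matches the script's per-step print, e.g.
# [rank=14]: step=27 memory peak_mem_gib=2.08..., current_mem_gib=2.08...
LINE_RE = re.compile(
    r"\[rank=(\d+)\]: step=(\d+) memory "
    r"peak_mem_gib=([\d.]+), current_mem_gib=([\d.]+)"
)


def summarize(lines: Iterable[str]) -> Dict[int, Tuple[int, Set[float]]]:
    """Map rank -> (last step seen, set of distinct current_mem_gib values)."""
    summary: Dict[int, Tuple[int, Set[float]]] = {}
    for line in lines:
        m = LINE_RE.search(line)
        if m is None:
            continue  # skip error messages and other non-matching lines
        rank, step = int(m.group(1)), int(m.group(2))
        current = float(m.group(4))
        last, values = summary.get(rank, (0, set()))
        values.add(current)
        summary[rank] = (max(last, step), values)
    return summary


log = """[rank=0]: step=28 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
[rank=0]: step=29 memory peak_mem_gib=2.083984375, current_mem_gib=2.083984375
terminate called after throwing an instance of 'ccl::v1::exception'"""
print(summarize(log.splitlines()))  # prints {0: (29, {2.083984375})}
```

A rank whose value set has more than one element would indicate growth in tracked memory; a single value per rank matches the flat readings shown above.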

The behavior seems specific to multi-node setups. I have not seen the same error raised on a single node.