pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Cannot move tensors to cpu when in a xmp spawn process #8271

Open radna0 opened 4 weeks ago

radna0 commented 4 weeks ago

🐛 Bug

    all_frames = torch.cat(all_frames, dim=0).cpu().numpy()
    RuntimeError: Bad StatusOr access: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:402) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 8 replicas, but found 2: %all-gather.20571 = f16[81920,8,64]{2,1,0} all-gather(f16[40960,8,64]{2,1,0} %add.20570), replica_groups={{0,1}}, dimensions={0}
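The message comes from XLA's HLO verifier. Below is a minimal pure-Python sketch of the check it names (`replica_count == 1 || n == replica_count`), assuming, as we read the verifier, that `n` is the number of distinct replica IDs named across all `replica_groups` of the collective. `check_replica_groups` is a hypothetical helper, not an XLA or torch_xla API.

```python
def check_replica_groups(replica_groups, replica_count):
    """Hypothetical mimic of the kCrossReplica replica-group check.

    Returns None when the groups pass, otherwise a message shaped like
    the one quoted in this issue.
    """
    # Assumption: n counts the distinct replica IDs named across all groups.
    n = len({replica for group in replica_groups for replica in group})
    if replica_count == 1 or n == replica_count:
        return None
    return (
        "In kCrossReplica mode, replica groups should contain "
        f"{replica_count} replicas, but found {n}"
    )


# The failing all-gather: replica_groups={{0,1}} in an 8-replica program.
print(check_replica_groups([[0, 1]], replica_count=8))
# Groups that together name all 8 replicas pass this particular check.
print(check_replica_groups([[0, 1], [2, 3], [4, 5], [6, 7]], replica_count=8))
```

Under this reading, a collective whose groups name only 2 of the program's 8 replicas is rejected before optimization, regardless of the tensor shapes involved.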

To Reproduce

Steps to reproduce the behavior:

  1. Spawn processes with xmp.spawn.
  2. Move a tensor to the CPU with .cpu().

Expected behavior

Tensors should be moved to the CPU without error.

Environment

Additional context

JackCaoG commented 4 weeks ago

Do you have a small code sample that reproduces this?

radna0 commented 4 weeks ago

You can clone this repo:

    git clone https://github.com/radna0/Video-Infinity.git

install the requirements with:

    pip install -r requirements.txt

and run the code with:

    accelerate launch tpu_inference.py --config examples/config.json

Let me know if I am missing anything.

radna0 commented 3 weeks ago

Were you able to reproduce the error? @JackCaoG

radna0 commented 3 weeks ago

It's been a week, and I'm still encountering this problem. I have tried several approaches, for example dist.gather(), tensor.cpu(), tensor.contiguous(), and other ways of saving tensors that also move them to the CPU, and all of them run into the same problem, even with xm.mark_step(). I haven't found a way around it, and the error is always the same: replica groups should contain 8 replicas, but found 2. Is there something I'm doing wrong? What I'm basically doing is the following:

  1. Spawn processes using torch_xla.launch()
  2. For each rank, create a distributed controller:

    
    import os

    import torch.distributed as dist
    import torch_xla
    import torch_xla.distributed.xla_backend  # registers the "xla" backend with torch.distributed


    class DistController(object):
        def __init__(self, rank, world_size, config) -> None:
            super().__init__()
            self.rank = rank
            self.world_size = world_size
            self.config = config
            self.is_master = rank == 0
            self.device = torch_xla.device()
            self.init_dist()
            self.init_group()

        def init_dist(self):
            print(
                f"Rank {self.rank}, {self.device} / {self.world_size} is running on XLA device."
            )
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = str(self.config.get("master_port") or "29500")
            dist.init_process_group("xla", rank=self.rank, world_size=self.world_size)

        def init_group(self):
            # One two-rank process group per pair of adjacent ranks: [0, 1], [1, 2], ...
            self.adj_groups = [
                dist.new_group([i, i + 1]) for i in range(self.world_size - 1)
            ]
            print(f"Rank {self.rank} initialized groups: {self.adj_groups}")
  3. Initialize the model, move it to the XLA device, then run inference.
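For the eight-device run in the error message, `init_group` creates seven overlapping two-rank groups. As a pure-Python sketch (hypothetical helpers `adjacent_groups` and `all_gather_shape`; no torch_xla involved), the snippet below lists the same rank pairs that `init_group` passes to `dist.new_group` and applies the all-gather shape rule, one shard per group member concatenated along dimension 0, to the shapes in the error's HLO: two `f16[40960,8,64]` shards give `f16[81920,8,64]`.

```python
def adjacent_groups(world_size):
    # Same rank lists that DistController.init_group passes to dist.new_group.
    return [[i, i + 1] for i in range(world_size - 1)]


def all_gather_shape(shard_shape, group_size, dim=0):
    # An all-gather concatenates one shard per group member along `dim`.
    out = list(shard_shape)
    out[dim] *= group_size
    return tuple(out)


groups = adjacent_groups(8)
print(groups)  # [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]]
print(all_gather_shape((40960, 8, 64), group_size=len(groups[0])))  # (81920, 8, 64)
```

Each pair names only 2 of the 8 replicas, which lines up with `replica_groups={{0,1}}` and `found 2` in the error. Because the pairs overlap (rank 1 appears in both `[0, 1]` and `[1, 2]`), the seven groups together are not a disjoint partition of the eight ranks either.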

    concurrent.futures.process._RemoteTraceback:
    """
    Traceback (most recent call last):
      File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
        r = call_item.fn(*call_item.args, **call_item.kwargs)
      File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
        return [fn(*args) for args in chunk]
      File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in
        return [fn(*args) for args in chunk]
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
        replica_results = list(
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
        yield _result_or_cancel(fs.pop())
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
        return fut.result(timeout)
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
        return self.__get_result()
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
        raise self._exception
      File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
        result = self.fn(*self.args, **self.kwargs)
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
        return fn()
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 185, in __call__
        self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
      File "/home/kojoe/Video-Infinity/tpu_inference.py", line 86, in main
        obj = run_inference(rank, config)
      File "/home/kojoe/Video-Infinity/tpu_inference.py", line 59, in run_inference
        obj = dist_pipe.inference(
      File "/home/kojoe/Video-Infinity/src/video_infinity/wrapper.py", line 241, in inference
        xm.mark_step()
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1046, in mark_step
        torch_xla._XLAC._xla_step_marker(
    RuntimeError: Bad StatusOr access: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:402) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 8 replicas, but found 2: %all-gather.320 = f16[81920,8,64]{2,1,0} all-gather(f16[40960,8,64]{2,1,0} %add.319), replica_groups={{0,1}}, dimensions={0}
    """

The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "/home/kojoe/Video-Infinity/tpu_inference.py", line 102, in
        torch_xla.launch(
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 233, in launch
        xmp.spawn(fn, args=args, nprocs=nprocs, start_method=start_method)
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 39, in spawn
        return pjrt.spawn(fn, nprocs, start_method, args)
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
        run_multiprocess(spawn_fn, start_method=start_method)
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
        replica_results = list(
      File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 170, in
        itertools.chain.from_iterable(
      File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
        for element in iterable:
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
        yield _result_or_cancel(fs.pop())
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
        return fut.result(timeout)
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
        return self.__get_result()
      File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
        raise self._exception
    RuntimeError: Bad StatusOr access: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:402) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 8 replicas, but found 2: %all-gather.320 = f16[81920,8,64]{2,1,0} all-gather(f16[40960,8,64]{2,1,0} %add.319), replica_groups={{0,1}}, dimensions={0}
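The traceback places the failure inside `xm.mark_step()` (`wrapper.py`, line 241), not at the collective call itself. torch_xla records operations on XLA tensors lazily and compiles the pending graph at `xm.mark_step()` or when a value must reach the host (for example through `.cpu()`), which would explain why every approach tried in this thread (`.cpu()`, `dist.gather()`, `xm.mark_step()`) ends in the same error: the rejected all-gather was recorded earlier, and whichever call first forces compilation reports it. The `LazyGraph` class below is a hypothetical pure-Python toy illustrating that deferral; it is not torch_xla code, and its validation step is a simplified stand-in for XLA's compiler checks.

```python
class LazyGraph:
    """Toy lazy tracer: ops are recorded now and validated only when materialized."""

    def __init__(self, replica_count):
        self.replica_count = replica_count
        self.pending = []  # recorded (op_name, replica_groups) pairs

    def all_gather(self, replica_groups):
        # Recording never fails, so the faulty op is not reported here.
        self.pending.append(("all-gather", replica_groups))

    def materialize(self):
        # Stand-in for mark_step()/.cpu(): "compile" every pending op at once.
        for op, groups in self.pending:
            named = {r for g in groups for r in g}
            if self.replica_count != 1 and len(named) != self.replica_count:
                raise RuntimeError(
                    f"{op}: replica groups should contain "
                    f"{self.replica_count} replicas, but found {len(named)}"
                )
        self.pending.clear()


graph = LazyGraph(replica_count=8)
graph.all_gather([[0, 1]])  # no error yet
try:
    graph.materialize()  # the error surfaces here, at materialization time
except RuntimeError as err:
    print(err)
```

In this toy, swapping which call triggers `materialize()` changes where the exception appears, but not the exception itself, because the problem lies in the recorded op.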