pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

xm.mark_step() RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED #7827

Open radna0 opened 1 month ago

radna0 commented 1 month ago

🐛 Bug

  File "/home/kojoe/EasyAnimate/easyanimate/pipeline/pipeline_easyanimate_inpaint.py", line 1369, in __call__
    latent_model_input = xs.mark_sharding(
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py", line 624, in mark_sharding
    annotate_func(unwrap_sharded_tensor(t), op_sharding)
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py", line 664, in unwrap_sharded_tensor
    return t.global_tensor
AttributeError: 'XLAShardedTensor' object has no attribute 'global_tensor'

To Reproduce

Steps to reproduce the behavior:

CODE:

import numpy as np
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# Define device mesh
num_devices = xr.global_runtime_device_count()

mesh_shape = (2, 2, 2, 1, 1)  # Adjust this based on your TPU configuration
device_ids = np.array(range(num_devices))
mesh = xs.Mesh(device_ids, mesh_shape, ("b", "c", "f", "h", "w"))

# Shard batch, channel, and frame dims; replicate height and width
partition_spec = ("b", "c", "f", None, None)

# latent_model_input is the pipeline's video latent (b, c, f, h, w), already on
# the XLA device (see the pipeline code later in this thread)
latent_model_input = xs.mark_sharding(
    latent_model_input, mesh, partition_spec
)

LOGS:

Number of devices:8
Mesh Shape: OrderedDict([('b', 2), ('c', 2), ('f', 2), ('h', 1), ('w', 1)])
Shape of the latent model input: torch.Size([2, 4, 1, 96, 96])

Expected behavior

mark_sharding should shard the latent, which is a video latent tensor of shape (b, c, f, h, w), along only the batch, channel, and frame dimensions, leaving height and width unsharded. The latent should be sharded across all devices of the TPU v3-8 instance.
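For reference, a minimal sketch of one way to check that the annotation was actually applied, reusing the mesh and partition spec from the repro snippet with a stand-in latent; the frame dimension is chosen divisible by the mesh's "f" axis to keep the shards even, and _get_xla_sharding_spec is a private torch_xla helper, so treat this as a debugging aid rather than a stable API:

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
mesh = xs.Mesh(np.arange(xr.global_runtime_device_count()),
               (2, 2, 2, 1, 1), ("b", "c", "f", "h", "w"))

# Stand-in video latent (b, c, f, h, w) on the XLA device
latent = torch.zeros(2, 4, 2, 96, 96, device=xm.xla_device())
xs.mark_sharding(latent, mesh, ("b", "c", "f", None, None))

# mark_sharding annotates the XLA tensor in place; the attached HLO sharding
# can be printed for debugging (an empty string means no sharding was applied)
print(torch_xla._XLAC._get_xla_sharding_spec(latent))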

Environment

Additional context

JackCaoG commented 1 month ago

Hmm, this is weird. I do see the attribute defined on the XLAShardedTensor class:

https://github.com/pytorch/xla/blob/ae2a3a4b01fef59159e3dbd8b8da80ee719bde4b/torch_xla/distributed/spmd/xla_sharded_tensor.py#L78

@alanwaketan can you take a look?

alanwaketan commented 1 month ago

Sure, will take a look.

alanwaketan commented 1 month ago

@radna0 Can you tell me how latent_model_input is defined?

radna0 commented 1 month ago

@alanwaketan Here's the code for that; it's a modified version of https://github.com/aigc-apps/EasyAnimate/blob/main/easyanimate/pipeline/pipeline_easyanimate_inpaint.py#L1071

latent_model_input = (
    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(
    latent_model_input, t
)

if (
    i < len(timesteps) * (1 - clip_apply_ratio)
    and clip_encoder_hidden_states_input is not None
):
    clip_encoder_hidden_states_actual_input = torch.zeros_like(
        clip_encoder_hidden_states_input
    )
    clip_attention_mask_actual_input = torch.zeros_like(
        clip_attention_mask_input
    )
else:
    clip_encoder_hidden_states_actual_input = (
        clip_encoder_hidden_states_input
    )
    clip_attention_mask_actual_input = clip_attention_mask_input

current_timestep = t
if not torch.is_tensor(current_timestep):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = latent_model_input.device.type == "mps"
    if isinstance(current_timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64
    current_timestep = torch.tensor(
        [current_timestep],
        dtype=dtype,
        device=latent_model_input.device,
    )
elif len(current_timestep.shape) == 0:
    current_timestep = current_timestep[None].to(
        latent_model_input.device
    )
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])

# get shape of the latent model input
shape = latent_model_input.shape
print(
    f"Shape of the latent model input: {shape}, device: {latent_model_input.device}"
)
latent_model_input = xs.mark_sharding(latent_model_input, mesh, partition_spec)

alanwaketan commented 1 month ago

@radna0 Am I missing anything? I'm not seeing where you move the tensor to the xla device.

radna0 commented 1 month ago

@alanwaketan Here's the modified code and logs; the tensor is already on the xla device.


        # Define device mesh
        num_devices = xr.global_runtime_device_count()
        print(f"Number of devices:{num_devices}")
        mesh_shape = (
            1,
            2,
            4,
            1,
            1,
        )  # Adjust this based on your TPU configuration
        device_ids = np.array(range(num_devices))
        mesh = xs.Mesh(device_ids, mesh_shape, ("b", "c", "f", "h", "w"))
        print(f"Mesh Shape: {mesh.shape()}")
        print(mesh.get_logical_mesh())

        # Define partition_spec (adjust according to the tensor shape and desired sharding)
        partition_spec = ("b", "c", "f", "h", "w")

        # 1. Check inputs. Raise error if not correct
        # 2. Default height and width to transformer

        if initDevice is not None:
            device = initDevice
        else:
            device = self._execution_device

        print("_________________Using Device:", device)

        # 3-9 remains the same

        # 10. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        if comfyui_progressbar:
            from comfy.utils import ProgressBar

            pbar = ProgressBar(num_inference_steps)

        max_mem = 0

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):

                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )

                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                ).to(device)

                if (
                    i < len(timesteps) * (1 - clip_apply_ratio)
                    and clip_encoder_hidden_states_input is not None
                ):
                    clip_encoder_hidden_states_actual_input = torch.zeros_like(
                        clip_encoder_hidden_states_input
                    )
                    clip_attention_mask_actual_input = torch.zeros_like(
                        clip_attention_mask_input
                    )
                else:
                    clip_encoder_hidden_states_actual_input = (
                        clip_encoder_hidden_states_input
                    )
                    clip_attention_mask_actual_input = clip_attention_mask_input

                current_timestep = t
                if not torch.is_tensor(current_timestep):
                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                    # This would be a good case for the `match` statement (Python 3.10+)
                    is_mps = latent_model_input.device.type == "mps"
                    if isinstance(current_timestep, float):
                        dtype = torch.float32 if is_mps else torch.float64
                    else:
                        dtype = torch.int32 if is_mps else torch.int64
                    current_timestep = torch.tensor(
                        [current_timestep],
                        dtype=dtype,
                        device=latent_model_input.device,
                    )
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None].to(
                        latent_model_input.device
                    )
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(latent_model_input.shape[0])

                # get shape of the latent model input
                shape = latent_model_input.shape
                print(
                    f"Shape of the latent model input: {shape}, device: {latent_model_input.device}"
                )
                latent_model_input = xs.mark_sharding(
                    latent_model_input, mesh, partition_spec
                )

LOGS:

Number of devices:8
Mesh Shape: OrderedDict([('b', 1), ('c', 2), ('f', 4), ('h', 1), ('w', 1)])
[[[[[0]]

   [[1]]

   [[2]]

   [[3]]]

  [[[4]]

   [[5]]

   [[6]]

   [[7]]]]]
_________________Using Device: xla:0
Shape of the latent model input: torch.Size([2, 4, 36, 90, 160]), device: xla:0

radna0 commented 1 month ago

Somehow I am no longer getting AttributeError: 'XLAShardedTensor' object has no attribute 'global_tensor'; I now get a different error, so I will retitle this issue to xm.mark_step() RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED.

Link to full code: https://github.com/radna0/EasyAnimate/blob/f8e2a2ac3ede7c115908ad39c57f9e4f7c299041/easyanimate/pipeline/pipeline_easyanimate_inpaint.py#L1007

  File "/home/kojoe/EasyAnimate/easyanimate/pipeline/pipeline_easyanimate_inpaint.py", line 1433, in __call__
    xm.mark_step()
  File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1008, in mark_step
    torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Allocation (size=59808153600) would exceed memory (size=17179869184) :: #allocation5812 [shape = 'f32[32,21600,21600]{1,2,0:T(8,128)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of fusion.124826@{1}'] :: <no-hlo-instruction>
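For reference, the numbers in the message line up: the fused output f32[32,21600,21600], with its last dimension padded from 21600 to 21632 by the (8,128) tile layout, needs 32 × 21600 × 21632 × 4 bytes = 59,808,153,600 bytes (~55.7 GiB), while the allocator only has 17,179,869,184 bytes (16 GiB) of HBM, which matches a single TPU v3 core.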

alanwaketan commented 1 month ago

Yea, it's OOMing. How large is the model? And have you sharded the model?

radna0 commented 4 weeks ago

The model, I believe, is 50GB in total. How would I shard a diffusion pipeline? I tried adding this code to shard the transformer and the VAE, but I still get the OOM error:


  def print_transformer_layers(self):
      print("Transformer Model:")

      # Print all attributes
      for name, module in self.transformer.named_modules():
          if not name or not hasattr(module, "weight"):
              continue
          try:
              spec = "dp"
              module_shape = len(module.weight.shape)
              if module_shape == 2:
                  spec = ("dp", "fsdp")
              elif module_shape >= 3:
                  needed = module_shape - 2
                  spec = ("dp", "fsdp") + tuple([None] * needed)
              print(module_shape, spec)
              xs.mark_sharding(module.weight, mesh, spec)
              print(f"Sharded Layer: {module}")
          except Exception as e:
              print(f"Error Layer {[res for res in module.named_modules()]}: {e}")
              pass

  def print_vae_layers(self):
      print("VAE Model:")

      # Print all attributes
      for name, module in self.vae.named_modules():
          if (
              not name
              or not module
              or not hasattr(module, "weight")
              or not hasattr(module, "shape")
          ):
              continue
          try:
              spec = "dp"
              module_shape = len(module.weight.shape)
              if module_shape == 2:
                  spec = ("dp", "fsdp")
              elif module_shape >= 3:
                  needed = module_shape - 2
                  spec = ("dp", "fsdp") + tuple([None] * needed)
              print(module_shape, spec)
              xs.mark_sharding(module.weight, mesh, spec)
              print(f"Sharded Layer: {module}")
          except Exception as e:
              print(f"Error Layer {[res for res in module.named_modules()]}: {e}")
              pass

alanwaketan commented 4 weeks ago

Yea, 50GB is huge. If you don't shard the model, it will certainly OOM.

@bhavya01 Can you help with diffusion model sharding?

radna0 commented 3 weeks ago

Can anyone help with this?

bhavya01 commented 3 weeks ago

Have you tried sharding the UNet layers? The UNet is the biggest part of the diffusion model, and I think sharding it would help the most.

radna0 commented 3 weeks ago

@bhavya01, I see there's only the transformer and the VAE, and the OOM happens during the denoising loop. Do you mind taking a look here?

https://github.com/radna0/EasyAnimate/blob/f8e2a2ac3ede7c115908ad39c57f9e4f7c299041/easyanimate/pipeline/pipeline_easyanimate_inpaint.py#L1007

JackCaoG commented 3 weeks ago

I think you can take a look at https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py#L47-L56 and just pick the layer you want to apply FSDP sharding to. The idea is that the layer you specify will be sharded across all devices; right before that layer executes, XLA does an all-gather to bring its parameters back to full size, and it drops the full parameters again once the layer is done.

For most decoder-only models the structure looks like

DecoderLayer1 + DecoderLayer2 + DecoderLayer3 + DecoderLayer4 + ...

In this case we can just shard the DecoderLayer, which means at any given time each device holds at most one layer's full weights.
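For illustration, here is a minimal sketch of that wrapping pattern, modeled on the linked FSDPv2 example; DecoderLayer and ToyModel below are toy placeholders (the real code would pass the diffusion transformer and its repeated block class), and the import paths assume a recent torch_xla release:

import functools

import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()


class DecoderLayer(nn.Module):
    """Toy stand-in for the repeated layer class you want to shard."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear(x)


class ToyModel(nn.Module):
    """Toy stand-in for the full model (e.g. the diffusion transformer)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(4)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# 1D mesh over all devices; FSDPv2 falls back to this global mesh by default
num_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",)))

# Wrap every DecoderLayer: its parameters stay sharded across devices and are
# only all-gathered for the duration of that layer's forward/backward pass
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={DecoderLayer})
model = FSDPv2(ToyModel().to(xm.xla_device()), auto_wrap_policy=auto_wrap_policy)

Presumably the same pattern would apply here by passing the EasyAnimate transformer's repeated block class as transformer_layer_cls instead of the toy DecoderLayer.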