radna0 opened 1 month ago
Hmm, this is weird. I do see the variable being part of the XLAShardedTensor class.
@alanwaketan can you take a look?
Sure, will take a look.
@radna0 Can you tell me how latent_model_input is defined?
@alanwaketan Here's the code for that; it is a modified version of https://github.com/aigc-apps/EasyAnimate/blob/main/easyanimate/pipeline/pipeline_easyanimate_inpaint.py#L1071
latent_model_input = (
    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

if (
    i < len(timesteps) * (1 - clip_apply_ratio)
    and clip_encoder_hidden_states_input is not None
):
    clip_encoder_hidden_states_actual_input = torch.zeros_like(
        clip_encoder_hidden_states_input
    )
    clip_attention_mask_actual_input = torch.zeros_like(
        clip_attention_mask_input
    )
else:
    clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
    clip_attention_mask_actual_input = clip_attention_mask_input

current_timestep = t
if not torch.is_tensor(current_timestep):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = latent_model_input.device.type == "mps"
    if isinstance(current_timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64
    current_timestep = torch.tensor(
        [current_timestep], dtype=dtype, device=latent_model_input.device
    )
elif len(current_timestep.shape) == 0:
    current_timestep = current_timestep[None].to(latent_model_input.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])

# get shape of the latent model input
shape = latent_model_input.shape
print(f"Shape of the latent model input: {shape}, device: {latent_model_input.device}")

latent_model_input = xs.mark_sharding(latent_model_input, mesh, partition_spec)
@radna0 Am I missing anything? I'm not seeing you move the tensor to the XLA device?
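(For reference, mark_sharding only annotates tensors that already live on the XLA device. A minimal sketch, assuming the mesh and partition_spec defined elsewhere in this thread; on older torch_xla releases the sharding module is torch_xla.experimental.xla_sharding rather than torch_xla.distributed.spmd:)

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs

device = xm.xla_device()
# Hypothetical latent tensor with the (b, c, f, h, w) layout from this thread.
latents = torch.randn(2, 4, 36, 90, 160).to(device)
# mark_sharding attaches the sharding annotation to the on-device tensor
# and returns an XLAShardedTensor wrapper.
latents = xs.mark_sharding(latents, mesh, partition_spec)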
@alanwaketan Here's the modified code and logs; the tensor is already on the XLA device.
# Define device mesh
num_devices = xr.global_runtime_device_count()
print(f"Number of devices:{num_devices}")
mesh_shape = (1, 2, 4, 1, 1)  # Adjust this based on your TPU configuration
device_ids = np.array(range(num_devices))
mesh = xs.Mesh(device_ids, mesh_shape, ("b", "c", "f", "h", "w"))
print(f"Mesh Shape: {mesh.shape()}")
print(mesh.get_logical_mesh())

# Define partition_spec (adjust according to the tensor shape and desired sharding)
partition_spec = ("b", "c", "f", "h", "w")

# 1. Check inputs. Raise error if not correct
# 2. Default height and width to transformer
if initDevice is not None:
    device = initDevice
else:
    device = self._execution_device
print("_________________Using Device:", device)

# 3-9 remains the same

# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if comfyui_progressbar:
    from comfy.utils import ProgressBar

    pbar = ProgressBar(num_inference_steps)
max_mem = 0
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = self.scheduler.scale_model_input(
            latent_model_input, t
        ).to(device)
        if (
            i < len(timesteps) * (1 - clip_apply_ratio)
            and clip_encoder_hidden_states_input is not None
        ):
            clip_encoder_hidden_states_actual_input = torch.zeros_like(
                clip_encoder_hidden_states_input
            )
            clip_attention_mask_actual_input = torch.zeros_like(
                clip_attention_mask_input
            )
        else:
            clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
            clip_attention_mask_actual_input = clip_attention_mask_input

        current_timestep = t
        if not torch.is_tensor(current_timestep):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = latent_model_input.device.type == "mps"
            if isinstance(current_timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            current_timestep = torch.tensor(
                [current_timestep], dtype=dtype, device=latent_model_input.device
            )
        elif len(current_timestep.shape) == 0:
            current_timestep = current_timestep[None].to(latent_model_input.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        current_timestep = current_timestep.expand(latent_model_input.shape[0])

        # get shape of the latent model input
        shape = latent_model_input.shape
        print(
            f"Shape of the latent model input: {shape}, device: {latent_model_input.device}"
        )
        latent_model_input = xs.mark_sharding(
            latent_model_input, mesh, partition_spec
        )
Number of devices:8
Mesh Shape: OrderedDict([('b', 1), ('c', 2), ('f', 4), ('h', 1), ('w', 1)])
[[[[[0]]
[[1]]
[[2]]
[[3]]]
[[[4]]
[[5]]
[[6]]
[[7]]]]]
_________________Using Device: xla:0
Shape of the latent model input: torch.Size([2, 4, 36, 90, 160]), device: xla:0
Somehow I am no longer getting AttributeError: 'XLAShardedTensor' object has no attribute 'global_tensor'; I am getting a different error instead, so I will update this issue: xm.mark_step() now raises RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED.
Link to full code: https://github.com/radna0/EasyAnimate/blob/f8e2a2ac3ede7c115908ad39c57f9e4f7c299041/easyanimate/pipeline/pipeline_easyanimate_inpaint.py#L1007
File "/home/kojoe/EasyAnimate/easyanimate/pipeline/pipeline_easyanimate_inpaint.py", line 1433, in __call__
xm.mark_step()
File "/home/kojoe/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 1008, in mark_step
torch_xla._XLAC._xla_step_marker(
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Allocation (size=59808153600) would exceed memory (size=17179869184) :: #allocation5812 [shape = 'f32[32,21600,21600]{1,2,0:T(8,128)}', space=hbm, size = 0xffffffffffffffff, tag = 'output of fusion.124826@{1}'] :: <no-hlo-instruction>
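(Side note: one way to see how much device memory is available before the loop OOMs is xm.get_memory_info. A minimal sketch; the exact field names in the returned dict vary across torch_xla releases:)

import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Returns a dict of device memory stats (e.g. free/total HBM on TPU).
print(xm.get_memory_info(device))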
Yea, it's OOMing. How large is the model? And have you sharded the model?
I believe the model is 50GB in total size. How would I shard a diffusion pipeline? I tried adding this code here to shard the transformer and the VAE, but I still get an OOM error:
def print_transformer_layers(self):
    print("Transformer Model:")
    # Shard every weight-carrying submodule of the transformer.
    for name, module in self.transformer.named_modules():
        if not name or not hasattr(module, "weight"):
            continue
        try:
            module_shape = len(module.weight.shape)
            # One partition-spec entry per weight dimension.
            spec = ("dp",)
            if module_shape == 2:
                spec = ("dp", "fsdp")
            elif module_shape >= 3:
                needed = module_shape - 2
                spec = ("dp", "fsdp") + tuple([None] * needed)
            print(module_shape, spec)
            xs.mark_sharding(module.weight, mesh, spec)
            print(f"Sharded Layer: {module}")
        except Exception as e:
            print(f"Error Layer {[res for res in module.named_modules()]}: {e}")

def print_vae_layers(self):
    print("VAE Model:")
    # Shard every weight-carrying submodule of the VAE.
    for name, module in self.vae.named_modules():
        if (
            not name
            or not hasattr(module, "weight")
            or not hasattr(module.weight, "shape")
        ):
            continue
        try:
            module_shape = len(module.weight.shape)
            # One partition-spec entry per weight dimension.
            spec = ("dp",)
            if module_shape == 2:
                spec = ("dp", "fsdp")
            elif module_shape >= 3:
                needed = module_shape - 2
                spec = ("dp", "fsdp") + tuple([None] * needed)
            print(module_shape, spec)
            xs.mark_sharding(module.weight, mesh, spec)
            print(f"Sharded Layer: {module}")
        except Exception as e:
            print(f"Error Layer {[res for res in module.named_modules()]}: {e}")
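(One thing to check: mark_sharding resolves axis names against the mesh, and the mesh defined earlier in this thread uses ("b", "c", "f", "h", "w"), so a spec of ("dp", "fsdp") would raise inside the try/except above. A minimal sketch of a two-axis mesh that matches these specs; the 1 x num_devices shape is illustrative:)

import numpy as np
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

num_devices = xr.global_runtime_device_count()
# "dp" for data parallelism, "fsdp" for sharding weight dimensions.
mesh = xs.Mesh(np.array(range(num_devices)), (1, num_devices), ("dp", "fsdp"))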
Yea, 50GB is huge. If you don't shard the model, it will certainly OOM.
@bhavya01 Can you help with diffusion model sharding?
Can anyone help with this?
Have you tried sharding the UNet layers? The UNet is the biggest part of a diffusion model, and I think sharding it would help the most.
@bhavya01, I see there's only the transformer and the VAE, and the OOM happens during the denoising loop. Do you mind taking a look here?
I think you can take a look at https://github.com/pytorch/xla/blob/master/examples/fsdp/train_decoder_only_fsdp_v2.py#L47-L56 and just pick the layer you want to apply FSDP sharding to. The idea is that the layer you specify will be sharded across all devices; before the layer executes, XLA will do an all-gather to bring the parameters to full size, and then drop the full parameters once the current layer is done.
Most decoder models look like
DecoderLayer1 + DecoderLayer2 + DecoderLayer3 + DecoderLayer4 + ...
In this case we can just shard the DecoderLayer, which means at any given time each device will hold at most one layer's full weights.
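(Following that linked example, a minimal sketch of wrapping a model with FSDPv2 under SPMD. DecoderLayer is a placeholder for whichever layer class you pick and model is the module to shard; the import paths are from recent torch_xla and may differ on older releases:)

import functools

import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
    SpmdFullyShardedDataParallel as FSDPv2,
)

xr.use_spmd()  # enable SPMD execution mode
xs.set_global_mesh(xs.get_1d_mesh("fsdp"))  # 1D mesh over all devices

# Every DecoderLayer is sharded across all devices; its weights are
# all-gathered right before the layer runs and released right after.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={DecoderLayer},  # placeholder: your layer class
)
model = FSDPv2(model, auto_wrap_policy=auto_wrap_policy)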
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Should shard the latent, which is a video latent tensor of shape (b, c, f, h, w), sharding only the batch, channel, and frame dimensions and omitting the width and height dimensions. The latent should be sharded among all devices within the TPU v3-8 instance.
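(A minimal sketch of a partition spec matching that expectation, assuming the mesh from this thread; None entries leave a dimension replicated:)

# Shard batch, channel, and frame; replicate height and width.
partition_spec = ("b", "c", "f", None, None)
latent_model_input = xs.mark_sharding(latent_model_input, mesh, partition_spec)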