huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

RuntimeError: weight should have at least three dimensions #2301

Closed: theoden8 closed this issue 1 month ago.

theoden8 commented 8 months ago

System Info

System info

```Shell
Using 2 GPUs nvidia 2080ti

# lsb_release -a
Distributor ID: Ubuntu
Description:    Ubuntu 20.04.4 LTS
Release:        20.04
Codename:       focal

# conda list | grep -E '(torch|transformers|accelerate|diffusers|deepspeed|numpy)'
accelerate                0.25.0          pyhd8ed1ab_0                  conda-forge
deepspeed                 0.12.2          cpu_py310h11dbdba_1           conda-forge
diffusers                 0.25.0          pyhd8ed1ab_0                  conda-forge
ffmpeg                    4.3             hf484d3e_0                    pytorch
numpy                     1.26.2          py310hb13e2d6_0               conda-forge
pytorch                   2.1.1           py3.10_cuda11.8_cudnn8.7.0_0  pytorch
pytorch-cuda              11.8            h7e8668a_5                    pytorch
pytorch-lightning         2.1.2           pyhd8ed1ab_0                  conda-forge
pytorch-mutex             1.0             cuda                          pytorch
torchaudio                2.1.1           py310_cu118                   pytorch
torchinfo                 1.8.0           pyhd8ed1ab_0                  conda-forge
torchmetrics              1.2.1           pyhd8ed1ab_0                  conda-forge
torchtriton               2.1.0           py310                         pytorch
torchvision               0.16.1          py310_cu118                   pytorch
transformers              4.35.2          pyhd8ed1ab_0                  conda-forge
```

Reproduction

I am trying to train the Stable Diffusion U-Net with Accelerate on multiple small GPUs. Here's my training script (adapted from https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py):

pipeline_accelerate_fsdp.py

```python
#!/usr/bin/env python3
import os
import sys
import random
import typing

import numpy as np
import tqdm.auto as tqdm
import torch
import torchvision
import datasets
import transformers
import diffusers
import accelerate
import xformers
import packaging


if __name__ == '__main__':
    accelerator = accelerate.Accelerator()
    accelerator.print('fsdp', accelerator.local_process_index)

    sd_model = 'runwayml/stable-diffusion-v1-5'
    output_dir = 'output_ct_simple'
    pipeline = diffusers.StableDiffusionPipeline.from_pretrained(
        sd_model,
        torch_dtype=torch.float32,
        device_map=None,
        safety_checker=None,
    )
    feature_extrator = pipeline.feature_extractor
    vae = pipeline.vae
    unet = pipeline.unet
    tokenizer = pipeline.tokenizer#.to(device=device)
    text_encoder = pipeline.text_encoder
    noise_scheduler = pipeline.scheduler
    vae.requires_grad_(True)
    unet.requires_grad_(False)
    text_encoder.requires_grad_(True)
    #tokenizer.requires_grad_(False)
    #print('after loading SDv1.5')
    #print_gpu_memory_usage()

    assert packaging.version.parse(xformers.__version__) != packaging.version.parse("0.0.16")
    unet.enable_xformers_memory_efficient_attention()

    dataset = datasets.load_dataset('fusing/fill50k', None, cache_dir=None)
    column_names = dataset['train'].column_names
    image_column = 'image'
    caption_column = 'text'
    conditioning_image_column = 'conditioning_image'

    def tokenize_captions(examples, is_train=True):
        proportion_empty_prompts = .0
        captions = []
        for caption in examples[caption_column]:
            if random.random() < proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    train_batch_size = 4
    resolution = 128
    image_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(resolution),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5], [0.5]),
    ])
    conditioning_image_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(resolution),
        torchvision.transforms.ToTensor(),
    ])

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]
        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
        conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {
            "pixel_values": pixel_values,
            "conditioning_pixel_values": conditioning_pixel_values,
            "input_ids": input_ids,
        }

    train_dataset = dataset["train"].with_transform(preprocess_train)
    num_workers = 0 # main process
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=train_batch_size,
        num_workers=num_workers,
    )

    learning_rate = 1e-4
    trainable_params = [p for p in list(unet.parameters()) + list(vae.parameters()) + list(text_encoder.parameters()) if p.requires_grad]
    optimizer = torch.optim.Adam(
        trainable_params,
        lr=learning_rate,
        betas=(.9, .999),
        weight_decay=1e-2,
        eps=1e-8,
    )
    lr_warmup_steps = 500
    lr_num_cycles = 1
    lr_power = 1.
    num_processes = 1
    epochs = 10
    lr_scheduler = diffusers.optimization.get_scheduler(
        'constant',
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=epochs * len(train_dataloader),
        num_cycles=lr_num_cycles,
        power=lr_power,
    )

    unet, vae, text_encoder, tokenizer = accelerator.prepare(
        unet, vae, text_encoder, tokenizer
    )
    optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
        optimizer, lr_scheduler, train_dataloader
    )

    if accelerator.is_main_process:
        pbar = tqdm.trange(epochs * len(train_dataset), initial=0, desc='steps', leave=True)
    for epoch in range(epochs):
        for batch in train_dataloader:
            pixel_values = batch["pixel_values"]
            pixel_values = pixel_values#.to(device=vae.device)
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            input_ids = batch['input_ids']#.to(device=text_encoder.device)
            encoder_hidden_states = text_encoder(input_ids)[0]#.float()

            pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
            ).sample

            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            if accelerator.is_main_process:
                pbar.set_postfix(loss=loss.detach().cpu().item())
                pbar.update(len(pixel_values))
```
accelerator_fsdp.yaml

```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: SIZE_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false
  fsdp_forward_prefetch: true
  fsdp_min_num_params: 100000000
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
run_pp_accelerate_fsdp.sh

```bash
#!/usr/bin/env bash

set -e

eval "$("${HOME}"/.miniconda3/bin/conda shell.bash hook)"
conda activate hfkernel

set -x
CUDA_HOME="/usr/local/cuda-11.1" \
OMP_NUM_THREADS=1 \
CUDA_VISIBLE_DEVICES=0,1 \
python -m accelerate.commands.launch \
  --config_file ./accelerate_fsdp.yaml \
  pipeline_accelerate_fsdp.py
```

I get this error:

```
+ CUDA_HOME=/usr/local/cuda-11.1
+ OMP_NUM_THREADS=1
+ CUDA_VISIBLE_DEVICES=0,1
+ python -m accelerate.commands.launch --config_file ./accelerate_fsdp.yaml pipeline_accelerate_fsdp.py
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
fsdp 0
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.94it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00,  6.97it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
steps:   0%|                                                                                                                                      | 0/500000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "$PWD/pipeline_accelerate_fsdp.py", line 201, in <module>
    latents = vae.encode(pixel_values).latent_dist.sample()
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 260, in encode
    h = self.encoder(x)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/autoencoders/vae.py", line 143, in forward
    sample = self.conv_in(sample)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: weight should have at least three dimensions
Traceback (most recent call last):
  File "$PWD/pipeline_accelerate_fsdp.py", line 201, in <module>
    latents = vae.encode(pixel_values).latent_dist.sample()
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 260, in encode
    h = self.encoder(x)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/autoencoders/vae.py", line 143, in forward
    sample = self.conv_in(sample)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: weight should have at least three dimensions
steps:   0%|                                                                                                                                      | 0/500000 [00:00<?, ?it/s]
[2024-01-02 15:16:24,291] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 3413506) of binary: $CONDA_PREFIX/bin/python
Traceback (most recent call last):
  File "$CONDA_PREFIX/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "$CONDA_PREFIX/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1027, in <module>
    main()
  File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1023, in main
    launch_command(args)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1004, in launch_command
    multi_gpu_launcher(args)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 666, in multi_gpu_launcher
    distrib_run.run(args)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
pipeline_accelerate_fsdp.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-01-02_15:16:24
  host      : ccp3.clostra.com
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3413507)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-01-02_15:16:24
  host      : ccp3.clostra.com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3413506)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
```
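The failing call is `F.conv2d` inside the VAE encoder's `conv_in`, and the message says the weight it received has fewer than three dimensions. As a minimal, standalone illustration (plain PyTorch on CPU, no FSDP or Accelerate involved, and not a diagnosis of this issue), the sketch below passes a normal 4-D convolution weight and then the same weight flattened to 1-D, which is roughly the form a flattened parameter takes. The tensor sizes are arbitrary, and the exact error text may vary between PyTorch versions, so the example only prints it.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)       # (batch, in_channels, height, width)
weight = torch.randn(4, 3, 3, 3)  # (out_channels, in_channels, kH, kW)

out = F.conv2d(x, weight)         # a normal 4-D weight works
print(out.shape)                  # torch.Size([1, 4, 6, 6])

flat_weight = weight.reshape(-1)  # same 108 values, but 1-D
caught = None
try:
    F.conv2d(x, flat_weight)      # conv2d rejects weights with fewer than 3 dims
except RuntimeError as err:
    caught = err
    print("RuntimeError:", err)
```

Note that the traceback above contains no `fully_sharded_data_parallel.py` frame between `vae.encode` and `conv_in`. One assumption consistent with that (not proven by the log) is that `encode` was reached without going through the FSDP wrapper's `forward`, so `conv_in.weight` was still in its flattened form; printing `weight.dim()` just before the failing call would be a cheap way to check.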

Expected behavior

I expected Accelerate to launch FSDP with model parallelism (not pipeline parallelism) and shard the models across the available devices.

I understand that FSDP flattens tensors in its underlying representation and shards them across devices, so I'm trying to understand:

1. Why does `nvidia-smi` show the model loaded on all devices? I'm not sure it is even sharded.

   1. I've tried printing parameter shapes before and after `accelerator.prepare`: the shape products (number of scalars per parameter) don't necessarily match before and after, and the parameters are flattened afterwards.
2. Is there a way to use the Accelerate library to employ model parallelism when training multiple models, as in the above scenario of training parts of Stable Diffusion?
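The flattening described above can be sketched with plain arithmetic. The function below is a simplified model, not PyTorch's implementation: it assumes the `FULL_SHARD` strategy (`fsdp_sharding_strategy: 1` in the config), concatenates the parameters of one wrapped unit into a single 1-D buffer, pads that buffer to a multiple of the world size, and gives each rank one contiguous slice. Any extra alignment padding the real implementation may insert is ignored. `shard_flat_params` and the toy layer shapes are hypothetical.

```python
import math
from typing import Dict, List, Tuple


def shard_flat_params(
    shapes: Dict[str, Tuple[int, ...]], world_size: int
) -> List[Dict[str, int]]:
    """Hypothetical helper: simplified model of FULL_SHARD flattening.

    All parameters of one wrapped unit are concatenated (in dict order) into a
    single 1-D buffer, padded up to a multiple of world_size, and split into
    equal contiguous slices, one per rank. Returns, for each rank, how many
    elements of each original parameter fall inside that rank's slice.
    """
    numels = {name: math.prod(shape) for name, shape in shapes.items()}
    total = sum(numels.values())
    # Ceiling division: padding the buffer makes every slice this size.
    shard_size = (total + world_size - 1) // world_size
    per_rank = []
    for rank in range(world_size):
        lo, hi = rank * shard_size, (rank + 1) * shard_size
        held, offset = {}, 0
        for name, n in numels.items():
            # Overlap of this parameter's range [offset, offset + n) with [lo, hi).
            held[name] = max(0, min(hi, offset + n) - max(lo, offset))
            offset += n
        per_rank.append(held)
    return per_rank


# Toy conv layer: weight (out=4, in=3, kH=3, kW=3) -> 108 scalars, bias -> 4 scalars.
shapes = {"conv.weight": (4, 3, 3, 3), "conv.bias": (4,)}
for rank, held in enumerate(shard_flat_params(shapes, world_size=2)):
    print(rank, held)
# 0 {'conv.weight': 56, 'conv.bias': 0}
# 1 {'conv.weight': 52, 'conv.bias': 4}
```

In this model, with two ranks neither rank holds the full 4-D `conv.weight`, only a 1-D piece of it, which is consistent with per-parameter shapes not matching before and after `accelerator.prepare`.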

I know there's DeepSpeed, but https://github.com/huggingface/accelerate/issues/253#issuecomment-1253231210 means I'd have to wrap everything into one model engine, so I was wondering whether Accelerate could handle it for me.

muellerzr commented 8 months ago

cc @pacman100

theoden8 commented 7 months ago

Wrapping the models into a single PyTorch module doesn't help:

pipeline_accelerate_fsdp_wrapped.py

```python
#!/usr/bin/env python3
import os
import sys
import random
import typing

import numpy as np
import tqdm.auto as tqdm
import torch
import torchvision
import datasets
import transformers
import diffusers
import accelerate
import xformers
import packaging


if __name__ == '__main__':
    accelerator = accelerate.Accelerator()
    accelerator.print('fsdp', accelerator.local_process_index)
    #assert accelerator.world_size <= torch.cuda.device_count(), f'{accelerator.local_process_index}: world_size={accelerator.world_size} > device_count={torch.cuda.device_count()}'

    sd_model = 'runwayml/stable-diffusion-v1-5'
    output_dir = 'output_ct_simple'
    pipeline = diffusers.StableDiffusionPipeline.from_pretrained(
        sd_model,
        torch_dtype=torch.float32,
        device_map=None,
        safety_checker=None,
    )
    feature_extrator = pipeline.feature_extractor
    vae = pipeline.vae
    unet = pipeline.unet
    tokenizer = pipeline.tokenizer#.to(device=device)
    text_encoder = pipeline.text_encoder
    noise_scheduler = pipeline.scheduler
    #vae.requires_grad_(True)
    #unet.requires_grad_(False)
    #text_encoder.requires_grad_(True)
    #tokenizer.requires_grad_(False)
    #print('after loading SDv1.5')
    #print_gpu_memory_usage()

    assert packaging.version.parse(xformers.__version__) != packaging.version.parse("0.0.16")
    unet.enable_xformers_memory_efficient_attention()

    dataset = datasets.load_dataset('fusing/fill50k', None, cache_dir=None)
    column_names = dataset['train'].column_names
    image_column = 'image'
    caption_column = 'text'
    conditioning_image_column = 'conditioning_image'

    enabled_fp16 = (accelerator.mixed_precision == 'fp16')

    class PipelineWrapper(torch.nn.Module):
        def __init__(self, vae, unet, tokenizer, text_encoder, noise_scheduler) -> None:
            super().__init__()
            self.vae, self.unet, self.tokenizer, self.text_encoder, self.noise_scheduler = \
                vae, unet, tokenizer, text_encoder, noise_scheduler

        def forward(self, batch):
            pixel_values = batch["pixel_values"]
            if enabled_fp16:
                pixel_values = pixel_values.half()
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            #print('latents', latents)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            input_ids = batch['input_ids']#.to(device=text_encoder.device)
            encoder_hidden_states = text_encoder(input_ids)[0]#.float()
            #print('encoder hidden states', encoder_hidden_states)

            #controlnet_image = batch["conditioning_pixel_values"]#.float()
            #print('controlnet image', controlnet_image)

            pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
            ).sample

            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
            return loss

    model_wrapper = PipelineWrapper(
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        noise_scheduler=noise_scheduler,
    )

    def tokenize_captions(examples, is_train=True):
        proportion_empty_prompts = .0
        captions = []
        for caption in examples[caption_column]:
            if random.random() < proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    train_batch_size = 4
    resolution = 128
    image_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(resolution),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5], [0.5]),
    ])
    conditioning_image_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(resolution),
        torchvision.transforms.ToTensor(),
    ])

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]
        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
        conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {
            "pixel_values": pixel_values,
            "conditioning_pixel_values": conditioning_pixel_values,
            "input_ids": input_ids,
        }

    train_dataset = dataset["train"].with_transform(preprocess_train)
    num_workers = 0 # main process
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=train_batch_size,
        num_workers=num_workers,
    )

    learning_rate = 1e-4
    optimizer = torch.optim.Adam(
        model_wrapper.unet.parameters(),
        lr=learning_rate,
        betas=(.9, .999),
        weight_decay=1e-2,
        eps=1e-8,
    )
    lr_warmup_steps = 500
    lr_num_cycles = 1
    lr_power = 1.
    num_processes = 1
    epochs = 10
    lr_scheduler = diffusers.optimization.get_scheduler(
        'constant',
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=epochs * len(train_dataloader),
        num_cycles=lr_num_cycles,
        power=lr_power,
    )

    model_wrapper.train()
    model_wrapper = accelerator.prepare(model_wrapper)
    optimizer, lr_scheduler, train_dataloader = accelerator.prepare(
        optimizer, lr_scheduler, train_dataloader
    )

    pbar = tqdm.trange(epochs * len(train_dataset), initial=0, desc='steps', leave=True)
    for epoch in range(epochs):
        for batch in train_dataloader:
            loss = model_wrapper(batch)
            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            pbar.set_postfix(loss=loss.detach().cpu().item())
            pbar.update(len(pixel_values))
```
Output

```
+ CUDA_HOME=/usr/local/cuda-11.1
+ OMP_NUM_THREADS=1
+ CUDA_VISIBLE_DEVICES=0,1,2,3
+ python -m accelerate.commands.launch --config_file ./accelerate_fsdp.yaml pipeline_accelerate_fsdp_wrapped.py
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
fsdp 0
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 6.40it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 6.11it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:01<00:00, 5.61it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:01<00:00, 4.83it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
steps: 0%| | 0/500000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 260, in <module>
    loss = model_wrapper(batch)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    loss = model_wrapper(batch)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
Traceback (most recent call last):
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 260, in <module>
    loss = model_wrapper(batch)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
Traceback (most recent call last):
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 260, in <module>
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    loss = model_wrapper(batch)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return forward_call(*args, **kwargs)
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 122, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    pred = unet(
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 122, in forward
    pred = unet(
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 122, in forward
    return forward_call(*args, **kwargs)
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 969, in forward
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    pred = unet(
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 969, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 969, in forward
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    emb = self.time_embedding(t_emb, timestep_cond)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    emb = self.time_embedding(t_emb, timestep_cond)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    emb = self.time_embedding(t_emb, timestep_cond)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return forward_call(*args, **kwargs)
  File "$PWD/pipeline_accelerate_fsdp_wrapped.py", line 122, in forward
    pred = unet(
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py",
line 1527, in _call_impl return forward_call(*args, **kwargs)return forward_call(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 228, in forward File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 228, in forward return self._call_impl(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 228, in forward return forward_call(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 969, in forward sample = self.linear_1(sample) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl sample = self.linear_1(sample)sample = self.linear_1(sample) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl emb = self.time_embedding(t_emb, timestep_cond) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) return self._call_impl(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return self._call_impl(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) return forward_call(*args, **kwargs) return forward_call(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward File 
"$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward return self._call_impl(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) return F.linear(input, self.weight, self.bias) return F.linear(input, self.weight, self.bias) File "$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/models/embeddings.py", line 228, in forward RuntimeErrorRuntimeError: mat2 must be a matrix, got 1-D tensor: mat2 must be a matrix, got 1-D tensor return F.linear(input, self.weight, self.bias) RuntimeError: mat2 must be a matrix, got 1-D tensor sample = self.linear_1(sample) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward return F.linear(input, self.weight, self.bias) RuntimeError: mat2 must be a matrix, got 1-D tensor steps: 0%| | 0/500000 [00:04 main() File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1023, in main launch_command(args) File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1004, in launch_command multi_gpu_launcher(args) File "$CONDA_PREFIX/lib/python3.10/site-packages/accelerate/commands/launch.py", line 666, in multi_gpu_launcher distrib_run.run(args) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/run.py", line 797, in run elastic_launch( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__ return launch_agent(self._config, self._entrypoint, list(args)) 
File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
pipeline_accelerate_fsdp_wrapped.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time : 2024-01-24_06:20:55
  host : ccp3.clostra.com
  rank : 1 (local_rank: 1)
  exitcode : 1 (pid: 203228)
  error_file:
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time : 2024-01-24_06:20:55
  host : ccp3.clostra.com
  rank : 2 (local_rank: 2)
  exitcode : 1 (pid: 203229)
  error_file:
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time : 2024-01-24_06:20:55
  host : ccp3.clostra.com
  rank : 3 (local_rank: 3)
  exitcode : 1 (pid: 203230)
  error_file:
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time : 2024-01-24_06:20:55
  host : ccp3.clostra.com
  rank : 0 (local_rank: 0)
  exitcode : 1 (pid: 203227)
  error_file:
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
```

But it fails at the unet call rather than at vae.encode.
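The last frames show `F.linear(input, self.weight, self.bias)` receiving a 1-D weight. As a minimal sketch (an illustration of the error messages only, not a diagnosis of this FSDP bug), both `RuntimeError`s seen in this issue appear whenever a layer's weight reaches the functional op as a flat 1-D tensor, which is the layout FSDP uses for its sharded `FlatParameter`s. The shapes are illustrative assumptions: 320 -> 1280 mirrors SD v1.5's `time_embedding.linear_1`, and `Conv2d(3, 128, 3)` mirrors the VAE encoder's first convolution.

```python
import torch
import torch.nn.functional as F


def error_message(fn):
    """Run fn and return the RuntimeError message it raises, or None."""
    try:
        fn()
    except RuntimeError as exc:
        return str(exc)
    return None


# Linear layer whose (1280, 320) weight has been flattened to 1-D.
x = torch.randn(4, 320)              # illustrative batch of timestep embeddings
flat_linear_w = torch.randn(1280 * 320)
linear_b = torch.randn(1280)
# With a 2-D input and a bias, F.linear dispatches to addmm, which rejects a 1-D weight.
linear_msg = error_message(lambda: F.linear(x, flat_linear_w, linear_b))

# Conv2d(3, 128, kernel_size=3) whose 4-D weight has been flattened to 1-D.
img = torch.randn(1, 3, 16, 16)
flat_conv_w = torch.randn(128 * 3 * 3 * 3)
conv_msg = error_message(lambda: F.conv2d(img, flat_conv_w, padding=1))

print(linear_msg)
print(conv_msg)
```

If the same messages show up in a real run, one thing worth checking (again an assumption, not a confirmed cause) is whether a submodule is being called in a way that bypasses its FSDP wrapper, so its parameters are still in their flattened, sharded form.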

theoden8 commented 7 months ago

If I use FSDP without Accelerate, it fails at loss.backward() (further along than with Accelerate), but the problem seems to be related:

pipeline_fsdp.py
```python
#!/usr/bin/env python3

import os
import sys
import random
import typing
import functools

import numpy as np
import tqdm.auto as tqdm
import torch
import torch.distributed.fsdp
import torchvision
import datasets
import transformers
import diffusers
import xformers
import packaging.version


def fsdp_setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)


def fsdp_cleanup():
    torch.distributed.destroy_process_group()


def fsdp_main(rank, world_size, args):
    print('fsdp rank', rank)
    fsdp_setup(rank, world_size)
    torch.cuda.set_device(rank)
    device = torch.cuda.current_device()
    #device = f'cuda:{rank}'

    auto_wrap_policy_size = functools.partial(
        torch.distributed.fsdp.wrap.size_based_auto_wrap_policy,
        min_num_params=100)

    sd_model = 'runwayml/stable-diffusion-v1-5'
    output_dir = 'output_ct_simple'
    pipeline = diffusers.StableDiffusionPipeline.from_pretrained(
        sd_model,
        torch_dtype=torch.float32,
        device_map=None,
        safety_checker=None,
    )
    feature_extrator = pipeline.feature_extractor
    vae = pipeline.vae
    unet = pipeline.unet
    tokenizer = pipeline.tokenizer#.to(device=device)
    text_encoder = pipeline.text_encoder
    noise_scheduler = pipeline.scheduler

    assert packaging.version.parse(xformers.__version__) != packaging.version.parse("0.0.16")
    unet.enable_xformers_memory_efficient_attention()

    dataset = datasets.load_dataset('fusing/fill50k', None, cache_dir=None)
    column_names = dataset['train'].column_names
    image_column = 'image'
    caption_column = 'text'
    conditioning_image_column = 'conditioning_image'

    enabled_fp16 = False

    class PipelineWrapper(torch.nn.Module):
        def __init__(self, vae, unet, tokenizer, text_encoder, noise_scheduler) -> None:
            super().__init__()
            self.vae, self.unet, self.tokenizer, self.text_encoder, self.noise_scheduler = \
                vae, unet, tokenizer, text_encoder, noise_scheduler

        def forward(self, batch):
            pixel_values = batch["pixel_values"]
            if enabled_fp16:
                pixel_values = pixel_values.half()
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            #print('latents', latents)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            input_ids = batch['input_ids']#.to(device=text_encoder.device)
            encoder_hidden_states = text_encoder(input_ids)[0]#.float()
            #print('encoder hidden states', encoder_hidden_states)

            #controlnet_image = batch["conditioning_pixel_values"]#.float()
            #print('controlnet image', controlnet_image)

            pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
            ).sample

            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
            loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
            return loss

    model_wrapper = PipelineWrapper(
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        noise_scheduler=noise_scheduler,
    )

    def tokenize_captions(examples, is_train=True):
        proportion_empty_prompts = .0
        captions = []
        for caption in examples[caption_column]:
            if random.random() < proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    train_batch_size = 4
    resolution = 128
    image_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(resolution),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5], [0.5]),
    ])
    conditioning_image_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(resolution, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        torchvision.transforms.CenterCrop(resolution),
        torchvision.transforms.ToTensor(),
    ])

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]
        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)
        return examples

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
        conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
        input_ids = torch.stack([example["input_ids"] for example in examples])
        return {
            "pixel_values": pixel_values,
            "conditioning_pixel_values": conditioning_pixel_values,
            "input_ids": input_ids,
        }

    train_dataset = dataset["train"].with_transform(preprocess_train)
    num_workers = 0 # main process
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=train_batch_size,
        num_workers=num_workers,
    )

    learning_rate = 1e-4
    #trainable_params = [p for p in list(unet.parameters()) + list(vae.parameters()) + list(text_encoder.parameters()) if p.requires_grad]
    # Note: the FSDP docs say the optimizer should be created after wrapping the
    # model, since wrapping replaces the module's original parameters.
    optimizer = torch.optim.Adam(
        model_wrapper.parameters(),
        #model_wrapper.unet.parameters(),
        lr=learning_rate,
        betas=(.9, .999),
        weight_decay=1e-2,
        eps=1e-8,
    )

    lr_warmup_steps = 500
    lr_num_cycles = 1
    lr_power = 1.
    num_processes = 1
    epochs = 10
    lr_scheduler = diffusers.optimization.get_scheduler(
        'constant',
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=epochs * len(train_dataloader),
        num_cycles=lr_num_cycles,
        power=lr_power,
    )

    # def param_init_fn(module: torch.nn.Module) -> None:
    #     module = module.to('cpu')

    model_wrapper = torch.distributed.fsdp.FullyShardedDataParallel(model_wrapper,
        cpu_offload=torch.distributed.fsdp.CPUOffload(offload_params=True),
        sharding_strategy=torch.distributed.fsdp.ShardingStrategy.SHARD_GRAD_OP,
        auto_wrap_policy=auto_wrap_policy_size,
        sync_module_states=True,
        #param_init_fn=param_init_fn,
        device_id=f'cuda:{rank}',
        #backward_prefetch=torch.distributed.fsdp.BackwardPrefetch.NO_PREFETCH,
        backward_prefetch=None,
        forward_prefetch=False,
    )#.to(device=device)

    model_wrapper.train()
    if rank == 0:
        pbar = tqdm.trange(epochs * len(train_dataset), initial=0, desc='steps', leave=True)
    first_step = True
    for epoch in range(epochs):
        for batch in train_dataloader:
            loss = model_wrapper(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()  # Adam.step() takes an optional closure, not the loss tensor
            lr_scheduler.step()
            if rank == 0:
                pbar.set_postfix(loss=loss.detach().cpu().item())
                # advance the progress bar by the number of samples in this batch
                pbar.update(len(batch["pixel_values"]))
            if first_step:
                first_step = False
    if rank == 0:
        pbar.close()
    torch.distributed.barrier()
    fsdp_cleanup()


if __name__ == '__main__':
    world_size = 2
    args = []
    torch.multiprocessing.spawn(fsdp_main, args=(world_size, args), nprocs=world_size, join=True)
```
run_pp_fsdp.sh
```bash
#!/usr/bin/env bash

set -e

eval "$("${HOME}"/.miniconda3/bin/conda shell.bash hook)"
conda activate hfkernel

set -x
CUDA_VISIBLE_DEVICES=0,1 \
    python3 \
    pipeline_fsdp.py
```
Output
```
+ python3 pipeline_fsdp.py
$CONDA_PREFIX/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
fsdp rank 7
fsdp rank 8
fsdp rank 3
fsdp rank 0
fsdp rank 4
fsdp rank 5
fsdp rank 6
fsdp rank 1
fsdp rank 9
fsdp rank 2
$CONDA_PREFIX/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
$CONDA_PREFIX/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:00<00:00, 6.99it/s] You have disabled the safety checker for by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 . steps: 0%| | 0/500000 [00:00 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param 
self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param 
self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in 
_post_backward_final_callback _catch_all_reshard(fsdp_state) Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in 
_catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File 
"$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return 
func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File 
"$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in 
decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard _reshard(state, state._handle, free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard handle.reshard(free_unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard 
self._free_unsharded_flat_param() File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param self._check_storage_allocated(unsharded_flat_param) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated _p_assert(storage_size > 0, "Expects storage to be allocated") File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack() Got exception in the catch-all reshard for FullyShardedDataParallel( (_fsdp_wrapped_module): Conv2d(4, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ): Expects storage to be allocated File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context return func(*args, **kwargs) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback _catch_all_reshard(fsdp_state) File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1136, in _catch_all_reshard _p_assert( File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 144, in _p_assert traceback.print_stack()
steps: 0%| | 0/500000 [00:23 torch.multiprocessing.spawn(fsdp_main,
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 158, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 4 terminated with the following error:
Traceback (most recent call last):
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "$PWD/pipeline_fsdp.py", line 286, in fsdp_main
    loss.backward()
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1092, in _post_backward_final_callback
    _catch_all_reshard(fsdp_state)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1141, in _catch_all_reshard
    raise e
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 1134, in _catch_all_reshard
    _reshard(state, state._handle, free_unsharded_flat_param)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 340, in _reshard
    handle.reshard(free_unsharded_flat_param)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1705, in reshard
    self._free_unsharded_flat_param()
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1738, in _free_unsharded_flat_param
    self._check_storage_allocated(unsharded_flat_param)
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2559, in _check_storage_allocated
    _p_assert(storage_size > 0, "Expects storage to be allocated")
  File "$CONDA_PREFIX/lib/python3.10/site-packages/torch/distributed/utils.py", line 146, in _p_assert
    raise AssertionError(s)
AssertionError: Expects storage to be allocated
```
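The log above interleaves the stderr of several ranks, which makes it hard to tell whether they all fail on the same FSDP unit. Here is a minimal sketch for checking that, assuming the output has been saved to a text file. The helper `failing_fsdp_units` is hypothetical (it is not part of torch or accelerate). It only scans text for the "Got exception in the catch-all reshard for ..." messages that FSDP prints and counts each distinct (wrapped module, error) pair.

```python
import re
from collections import Counter

# Hypothetical helper, not part of torch or accelerate: it only scans log text.
# FSDP's catch-all reshard prints messages of the form
#   Got exception in the catch-all reshard for FullyShardedDataParallel(
#     (_fsdp_wrapped_module): <module repr>
#   ): <error message>
# Assumption: the wrapped module repr itself contains no "):", which holds for
# the leaf Conv2d units in this log but not for nested container modules.
_RESHARD_RE = re.compile(
    r"Got exception in the catch-all reshard for FullyShardedDataParallel\(\s*"
    r"\(_fsdp_wrapped_module\):\s*(?P<module>.+?)\s*\):\s*"
    # The error text ends where the next interleaved frame or message begins.
    r"(?P<error>.+?)(?=\s+(?:File |Got exception|steps:)|\s*$)",
    re.DOTALL,  # tolerate messages that were wrapped across lines
)


def failing_fsdp_units(log_text: str) -> Counter:
    """Count (module repr, error message) pairs reported by the catch-all reshard."""
    return Counter(
        (m.group("module"), m.group("error"))
        for m in _RESHARD_RE.finditer(log_text)
    )
```

For example, `failing_fsdp_units(open("stderr.log").read()).most_common()` (the file name is a placeholder) should report a single entry for the log in this issue, since every catch-all message in it names the same `Conv2d(4, 512, kernel_size=(3, 3), ...)` unit. As far as I can tell, in the Stable Diffusion v1.5 VAE (`AutoencoderKL` with 4 latent channels and 512 channels in its widest block), a 3x3 `Conv2d` from 4 to 512 channels is the decoder's `conv_in`. That makes it a plausible place to start looking, but treat it as an unverified guess.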
github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

theoden8 commented 6 months ago

Any update on this? Could we try to reopen it? @pacman100

github-actions[bot] commented 5 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

theoden8 commented 4 months ago

:/

samsara-ku commented 2 months ago

@theoden8, is there any solution for this problem? I'm stuck on it :(

theoden8 commented 2 months ago

I stopped investigating some time ago.

samsara-ku commented 2 months ago

Thanks for replying 👍

drimeF0 commented 2 months ago

Same error. My code was supposed to train a VAE on 8 v2 TPUs using FSDP.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.