huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

device_map="auto" in AutoPipelineForText2Image raises error #6240

Open Nidhogg-lyz opened 9 months ago

Nidhogg-lyz commented 9 months ago

Describe the bug

Thanks for your contributions! I'm using AutoPipelineForText2Image with device_map="auto" at initialization, but inference raises Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!. Even when I manually put the essential components on target devices (text_encoder and unet on cuda:0, vae on cuda:1) and build a pipeline through StableDiffusionPipeline, this keeps happening. Should I rewrite the pipeline process in my code?

Reproduction

This is a simple script for reproduction:

import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-2-1", device_map="auto"
)

generator = torch.Generator().manual_seed(31)
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]

Logs

Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument tensors in method wrapper_CUDA_cat)

System Info

I'm using python=3.11, diffusers=0.24.0, accelerate=0.25.0, torch=2.1.2 with CUDA=11.8

Who can help?

@yiyixuxu @DN6

sayakpaul commented 9 months ago

Should I rewrite the pipeline process in my code?

I think so, yes. Pinging @muellerzr for advice on this.

Note that if you are planning to leverage multiple GPUs during inference, we have some documentation:

If you're in a multi-GPU environment and want to use just a single GPU, I think it's best to prepend CUDA_VISIBLE_DEVICES=<DEVICE_ID> to the command that launches your Python program.
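A minimal sketch of the same idea driven from Python, for cases where the generation job is started by another script. It copies the environment, sets CUDA_VISIBLE_DEVICES, and launches a child process. The helper names `env_for_single_gpu` and `launch_on_gpu` are hypothetical, not part of diffusers. Setting the variable inside an already-running process only takes effect if it happens before CUDA is initialized, which is why a separate child process is used here.

```python
import os
import subprocess
import sys


def env_for_single_gpu(device_id: int, base_env=None) -> dict:
    """Return a copy of the environment that exposes only one GPU to a child process."""
    # Copy so the caller's environment (or os.environ) is never mutated.
    env = dict(os.environ if base_env is None else base_env)
    # CUDA will enumerate only this device; inside the child it appears as cuda:0.
    env["CUDA_VISIBLE_DEVICES"] = str(device_id)
    return env


def launch_on_gpu(script: str, device_id: int) -> int:
    """Run `script` with the current interpreter, restricted to a single GPU."""
    result = subprocess.run([sys.executable, script], env=env_for_single_gpu(device_id))
    return result.returncode
```

For example, `launch_on_gpu("generate.py", 1)` behaves roughly like `CUDA_VISIBLE_DEVICES=1 python generate.py` (the script name is a placeholder). Inside the child, the selected GPU shows up as cuda:0, so the pipeline can simply be moved to "cuda".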

muellerzr commented 9 months ago

cc @SunMarc

SunMarc commented 9 months ago

Hi @Nidhogg-lyz, thanks for reporting. This happens because device_map="auto" is not fully supported in diffusers. As a result, some modules get split across devices when they should not be (see no_split_module_classes). In the meantime, I advise you not to use device_map="auto". Otherwise, you will get errors even if you try to move the model afterwards. @sayakpaul
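To make the failure mode concrete, here is a hypothetical diagnostic (not a diffusers or accelerate API) that takes a module-to-device mapping and reports every top-level component spread over more than one device. It assumes the mapping uses dotted module names as keys, similar in shape to the device_map dictionaries accelerate works with; a real pipeline's map may be keyed differently.

```python
from collections import defaultdict


def find_split_components(device_map: dict) -> dict:
    """Report top-level components whose submodules live on more than one device.

    `device_map` is assumed to map dotted module names (e.g. "unet.down_blocks.0")
    to a device (e.g. 0, 1, "cpu"). Returns {component: sorted list of devices}
    for every component that is split.
    """
    devices_per_component = defaultdict(set)
    for module_name, device in device_map.items():
        # "unet.down_blocks.0" -> "unet"; a bare "vae" stays "vae".
        component = module_name.split(".", 1)[0]
        devices_per_component[component].add(str(device))
    return {
        name: sorted(devices)
        for name, devices in devices_per_component.items()
        if len(devices) > 1
    }
```

A non-empty result means a component's forward pass crosses devices internally, which is the situation where an op such as torch.cat can receive tensors from both cuda:0 and cuda:1.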

patrickvonplaten commented 9 months ago

I think we could start thinking about how best to support device_map="auto" in diffusers. @SunMarc the difference between diffusers and transformers is that diffusers pipelines are not torch.nn.Module objects; instead, they define a chain of multiple torch.nn.Module objects. This chain is defined in every pipeline here: https://github.com/huggingface/diffusers/blob/6e123688dc63ae1b49e085d1c228f935a7e187fd/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L154 which we currently also use for CPU offloading. I think it could be relatively easy to write a generic function in https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_utils.py that makes sure components are correctly moved to different GPU devices when the user passes device_map="auto".
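One possible shape for such a generic function, sketched under assumptions: the component chain is given as a string like "text_encoder->unet->vae" (assumed to match the offload sequence linked above), and the caller supplies rough per-component memory sizes and per-device capacities. `plan_component_placement` is a hypothetical name. It only plans the placement and always keeps each component whole on one device, which avoids the module splitting described earlier in the thread.

```python
def plan_component_placement(offload_seq: str, sizes_gb: dict, capacities_gb: dict) -> dict:
    """Assign each pipeline component, whole, to a single device.

    Components are visited in chain order (e.g. "text_encoder->unet->vae") and each
    one goes to the first device, in the order of `capacities_gb`, that still has
    room for it. Components are never split across devices.
    """
    remaining = dict(capacities_gb)
    placement = {}
    for name in offload_seq.split("->"):
        if name not in sizes_gb:
            continue  # component absent from this pipeline (e.g. no image_encoder)
        size = sizes_gb[name]
        for device, free in remaining.items():
            if size <= free:
                placement[name] = device
                remaining[device] = free - size
                break
        else:
            raise ValueError(f"No single device has room for {name!r} ({size} GB)")
    return placement
```

First-fit in chain order tends to keep neighbouring components on the same device, which keeps the number of cross-device copies low. A real implementation would also need measured model sizes and headroom for activations.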

patrickvonplaten commented 9 months ago

This could be a cool feature to work on in case someone is interested

Nidhogg-lyz commented 9 months ago

Thanks for the advice! I've tried to rewrite the pipeline myself in which I manually put the components on target devices and that works fine for me.

sayakpaul commented 9 months ago

Feel free to show us your implementation. That would be beneficial for our own learning too!

Nidhogg-lyz commented 9 months ago

@sayakpaul Since I loaded each component and put it on its target device manually, I just inherited from StableDiffusionPipeline and changed a few lines in __call__ to place each input on the correct device. But this approach assumes that no component is itself split into smaller pieces across devices, which is not so "auto". Hope this is helpful! Here is the code for the modified __call__ function:

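import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
# helpers from the stock pipeline module (their location may vary across diffusers versions)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps


class MultiDeviceStableDiffusionPipeline(StableDiffusionPipeline):  # illustrative class name
    def __call__(self, prompt=None, height=None, width=None, num_inference_steps=10, timesteps=None, guidance_scale=7.5,
                 negative_prompt=None, num_images_per_prompt=1, eta=0.0, generator=None, latents=None, prompt_embeds=None,
                 negative_prompt_embeds=None, output_type="pil", return_dict=True, callback=None, callback_steps=1,
                 cross_attention_kwargs=None, guidance_rescale=0.0):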

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        ### Because the components may not be on the same device, self._execution_device cannot be used
        # device = self._execution_device

        # 3. Encode input prompt
        lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        ### get device of prompt
        if self.text_encoder is None:
            prompt_device = self.unet.device
        else:
            prompt_device = self.text_encoder.device

        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            prompt_device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            clip_skip=None,
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        ### put prompt_embeds to unet device
        prompt_embeds = prompt_embeds.to(self.unet.device)

        # 4. Prepare timesteps

        ### get timestep device
        timestep_device = self.unet.device

        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timestep_device, timesteps)

        # 5. Prepare latent variables

        ### get latents device
        latents_device = self.unet.device

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            latents_device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs =  None

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=self.unet.device, dtype=latents.dtype) ### put timestep to unet device

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            ### put latents to vae device
            latents = latents.to(self.vae.device)
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, self.vae.device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
sayakpaul commented 9 months ago

Ah, I see. You have manually separated the device for each stage (prompt_device, timestep_device, latents_device, and so on).
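The per-stage device handling in the modified __call__ boils down to copying tensors at the boundaries between components (prompt_embeds.to(self.unet.device) and latents.to(self.vae.device)). The small helper below, `boundary_transfers`, is hypothetical and only computes where such copies are needed. It assumes the component order is passed as a string like "text_encoder->unet->vae" and the placement as a {component: device} dict.

```python
def boundary_transfers(offload_seq: str, placement: dict) -> list:
    """List hand-offs where one component's output must move to another device.

    Returns (producer, consumer, source_device, target_device) tuples for each
    pair of consecutive components in `offload_seq` that live on different devices.
    Components missing from `placement` are skipped.
    """
    stages = [name for name in offload_seq.split("->") if name in placement]
    transfers = []
    for producer, consumer in zip(stages, stages[1:]):
        src, dst = placement[producer], placement[consumer]
        if src != dst:
            transfers.append((producer, consumer, src, dst))
    return transfers
```

With the layout from the original report (text_encoder and unet on cuda:0, vae on cuda:1), the only hand-off is unet to vae, which corresponds to the single latents.to(self.vae.device) call before decoding.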

sayakpaul commented 9 months ago

@patrickvonplaten I will give it a try in the coming days. I have assigned myself.

github-actions[bot] commented 8 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

xraywu commented 7 months ago

Hi, has this issue been resolved? I upgraded diffusers to 0.26.3 but the same error still happens.

My test code:

from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
        "/data/llm_modles/modelscope/hub/AI-ModelScope/sdxl-turbo", torch_dtype=torch.float16, use_safetensors=True, device_map="auto"
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

image = pipeline(prompt, num_inference_steps=25).images[0]
sayakpaul commented 7 months ago

No, it hasn't been. That is why I reopened it.

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sayakpaul commented 6 months ago

Not stale. This is being actively worked on in https://github.com/huggingface/diffusers/pull/6857

github-actions[bot] commented 5 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.