huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Bug in load_lora_weights when adding align_device_hook to a model #7539

Open zhangvia opened 7 months ago

zhangvia commented 7 months ago

Describe the bug

I noticed that when I manually add an align_device_hook to a module in the pipeline, the load_lora_weights function enables sequential CPU offload. Digging deeper, I found that load_lora_weights uses the _optionally_disable_offloading function to decide whether to apply sequential CPU offload. _optionally_disable_offloading looks like this:

    def _optionally_disable_offloading(cls, _pipeline):
        """
        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

        Args:
            _pipeline (`DiffusionPipeline`):
                The pipeline to disable offloading for.

        Returns:
            tuple:
                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
        """
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False

        if _pipeline is not None:
            for _, component in _pipeline.components.items():
                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    if not is_sequential_cpu_offload:
                        is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)

                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        return (is_model_cpu_offload, is_sequential_cpu_offload)

So I'm curious: why is is_sequential_cpu_offload set to True whenever a component has an AlignDevicesHook? Shouldn't it be True only when the component's device is the CPU?
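
To isolate the behaviour being asked about, here is a minimal pure-Python sketch of the type-based check. AlignDevicesHook, CpuOffload and Component are hypothetical stand-ins (not accelerate's real classes), and detect_offloading only replicates the two isinstance tests quoted above; it does not remove hooks or touch devices.

```python
class AlignDevicesHook:
    """Hypothetical stand-in for accelerate.hooks.AlignDevicesHook."""

    def __init__(self, execution_device):
        self.execution_device = execution_device


class CpuOffload:
    """Hypothetical stand-in for accelerate.hooks.CpuOffload."""


class Component:
    """Hypothetical pipeline component; `_hf_hook` is set only when hooked."""

    def __init__(self, hook=None):
        if hook is not None:
            self._hf_hook = hook


def detect_offloading(components):
    """Replicates the type-based checks quoted from _optionally_disable_offloading."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in components.values():
        if hasattr(component, "_hf_hook"):
            if not is_model_cpu_offload:
                is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
            if not is_sequential_cpu_offload:
                is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
    return is_model_cpu_offload, is_sequential_cpu_offload


# Hooks attached only to pin modules to GPUs: nothing lives on the CPU,
# yet the type check still reports sequential CPU offload.
manual = {
    "unet": Component(AlignDevicesHook("cuda:0")),
    "vae": Component(AlignDevicesHook("cuda:1")),
}
print(detect_offloading(manual))  # (False, True)
```

The point of the sketch is that the check looks only at the hook's type, never at the execution device, so GPU-only placement hooks are indistinguishable from offloading hooks.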

Reproduction

import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel
from accelerate.hooks import attach_align_device_hook_on_blocks

# controlnet1 and controlnet2 are ControlNetModel instances loaded beforehand
# (their checkpoints are not given in this report).
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "/media/74nvme/checkpoints/diffusers_models/stable-diffusion-v1-5/",
    controlnet=[controlnet1, controlnet2],
    torch_dtype=torch.float16,
).to("cuda:0")

# Collect the pipeline's nn.Module components and their names.
module_names, _ = pipe._get_signature_keys(pipe)
modules = [getattr(pipe, n, None) for n in module_names]
module_names = [name for m, name in zip(modules, module_names) if isinstance(m, torch.nn.Module)]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
print(module_names)

# Pin unet/controlnet to cuda:0 and every other module to cuda:1, then attach an
# AlignDevicesHook so inputs are moved to each module's execution device.
for module, name in zip(modules, module_names):
    device = "cuda:0" if name in ("unet", "controlnet") else "cuda:1"
    module.to(device)
    attach_align_device_hook_on_blocks(
        module,
        execution_device=module.device,
    )

Then calling pipe.load_lora_weights(lora_weights_path) changes the device of every component.

Logs

No response

System Info

diffusers:0.25.1 torch:2.2.0+cu118

Who can help?

No response

yiyixuxu commented 7 months ago

Hi,

so i was curious that why is_sequential_cpu_offload = True when component has AlignDevicesHook? Shouldn't it be True only when the component device is CPU?

By "sequential_cpu_offload" we are referring to the enable_sequential_cpu_offload method that you can call on our pipelines:

https://github.com/huggingface/diffusers/blob/19ab04ff56cb4389f6a988b985b76da48a75ada9/src/diffusers/pipelines/pipeline_utils.py#L1037

I don't think this is a bug, no?

zhangvia commented 7 months ago

I think PR #6857 will add an AlignDevicesHook to every model in the pipeline. So if I use that new feature and also load LoRA weights, enable_sequential_cpu_offload will be called inside the load_lora_weights method. I want to know whether my understanding is correct. Thank you for the explanation.

sayakpaul commented 7 months ago

i think the pr #6857 will add AlignDevicesHook to every model in the pipeline. so if i use that new feature and load lora weights simultaneously, enable_sequential_cpu_offload will be called in load_lora_weights method. I wanna know if my understanding is correct. thank you for explanantion

The PR addresses a different problem; I'm not sure how it's related.

zhangvia commented 7 months ago

The PR addresses a different problem

Yes, but that PR needs AlignDevicesHook to prepare the model inputs before calling the model's forward method. If I then call the load_lora_weights method, enable_sequential_cpu_offload() ends up being called because of the AlignDevicesHook. The call chain looks like this:

load_lora_weights -> load_lora_into_unet -> _pipeline.enable_sequential_cpu_offload()

Here is the relevant part of load_lora_into_unet:

            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
            # otherwise loading LoRA weights will lead to an error
            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)

            if incompatible_keys is not None:
                # check only for unexpected keys
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    logger.warning(
                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                        f" {unexpected_keys}. "
                    )

            # Offload back.
            if is_model_cpu_offload:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

load_lora_into_unet calls cls._optionally_disable_offloading(_pipeline) to decide whether to call enable_sequential_cpu_offload().

In _optionally_disable_offloading(_pipeline), is_sequential_cpu_offload is set to True because of the AlignDevicesHook on the module. The relevant part of _optionally_disable_offloading is:

            for _, component in _pipeline.components.items():
                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                    if not is_model_cpu_offload:
                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                    if not is_sequential_cpu_offload:
                        is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)

                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                    )
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

I'm not sure whether I'm right. I just want to figure out why calling load_lora_weights after adding an AlignDevicesHook to a module sets every component's device to meta. Thank you for your patient explanation.
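
The remove, load, re-enable sequence quoted above can be modeled in pure Python to show why manually placed components end up on meta. FakePipeline and load_lora_flow are hypothetical stand-ins; the device strings and the "meta" result are modeled on what this issue reports, not produced by diffusers or accelerate.

```python
class FakePipeline:
    """Hypothetical stand-in recording component devices and offload calls."""

    def __init__(self, devices):
        self.devices = dict(devices)  # component name -> device string
        self.hooked = set(devices)    # components that carry a manual hook
        self.calls = []

    def enable_model_cpu_offload(self):
        self.calls.append("enable_model_cpu_offload")

    def enable_sequential_cpu_offload(self):
        # Models a sequential offload: weights leave the GPUs and the module
        # parameters are left on the "meta" device.
        self.calls.append("enable_sequential_cpu_offload")
        self.devices = {name: "meta" for name in self.devices}


def load_lora_flow(pipe, is_model_cpu_offload, is_sequential_cpu_offload):
    """Mirrors the remove -> load -> re-enable sequence of load_lora_into_unet."""
    pipe.hooked.clear()  # stands in for remove_hook_from_module(...)
    # ... inject_adapter_in_model / set_peft_model_state_dict would run here ...
    if is_model_cpu_offload:
        pipe.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    return pipe


pipe = FakePipeline({"unet": "cuda:0", "vae": "cuda:1"})
# The flags that a manual AlignDevicesHook produces: (False, True).
load_lora_flow(pipe, is_model_cpu_offload=False, is_sequential_cpu_offload=True)
print(pipe.devices)  # {'unet': 'meta', 'vae': 'meta'}
print(pipe.hooked)   # set()
```

In this model the manual hooks are never restored; the only thing "applied again" is whatever offloading method the flags select.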

zhangvia commented 7 months ago

I found that in the pipeline-device-map-auto branch, the PR disables enable_sequential_cpu_offload() when device_map="balanced" is used. But I'm still a little confused about why we need to call enable_sequential_cpu_offload() in load_lora_weights when an AlignDevicesHook has been added to models in the pipeline.

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

sayakpaul commented 4 months ago

Is this still a problem?

zhangvia commented 4 months ago

Is this still a problem?

Given that device_map=auto can only place models on different GPUs according to model size, I add an AlignDevicesHook to every model manually. But the load_lora_weights method removes all AlignDevicesHooks and never adds them back, which really confused me. I renamed the hook class so that load_lora_weights no longer removes the hook, and so far my code runs as expected.
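
Renaming the hook class works because the check is a plain isinstance test. One pitfall worth noting, assuming the renamed hook is otherwise a copy of AlignDevicesHook: isinstance also matches subclasses, so a renamed hook that still inherits from AlignDevicesHook would be detected again. A pure-Python sketch with hypothetical stand-in classes:

```python
class AlignDevicesHook:
    """Hypothetical stand-in for accelerate.hooks.AlignDevicesHook."""

    def __init__(self, execution_device):
        self.execution_device = execution_device


class SubclassedHook(AlignDevicesHook):
    """A new name via inheritance: isinstance still sees an AlignDevicesHook."""


class RenamedHook:
    """The same behaviour re-declared under a new name, without inheriting."""

    def __init__(self, execution_device):
        self.execution_device = execution_device


def looks_like_sequential_offload(hook):
    # The same type test used by _optionally_disable_offloading.
    return isinstance(hook, AlignDevicesHook)


print(looks_like_sequential_offload(SubclassedHook("cuda:0")))  # True
print(looks_like_sequential_offload(RenamedHook("cuda:0")))     # False
```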

sayakpaul commented 4 months ago

but the load_lora_weights method will remove all AlignDevicesHook

That is only temporary. We add them back. See here:

https://github.com/huggingface/diffusers/blob/7bfc1ee1b2cb0961ff111f50a9d096816e4dd921/src/diffusers/loaders/lora.py#L517

zhangvia commented 4 months ago

    if is_model_cpu_offload:
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()

It just calls enable_model_cpu_offload() or enable_sequential_cpu_offload(). Actually, I never call enable_model_cpu_offload() or enable_sequential_cpu_offload(); my AlignDevicesHooks are added to the models manually to place them on different GPUs. I don't think we can assume that enable_model_cpu_offload() or enable_sequential_cpu_offload() has been called just because some model in the pipeline has an AlignDevicesHook.

sayakpaul commented 4 months ago

it just call the enable_model_cpu_offload() or enable_sequential_cpu_offload().

Those methods are responsible for placing the hooks.

For example: https://github.com/huggingface/diffusers/blob/7bfc1ee1b2cb0961ff111f50a9d096816e4dd921/src/diffusers/pipelines/pipeline_utils.py#L1077

The thing is, we are not supposed to call any offloading-related utilities manually when any component of a pipeline was initialized with the "balanced" device_map. This should be sufficiently clear from the errors:

https://github.com/huggingface/diffusers/blob/7bfc1ee1b2cb0961ff111f50a9d096816e4dd921/src/diffusers/pipelines/pipeline_utils.py#L1030

zhangvia commented 4 months ago

Thing is we are not supposed to call any offloading related utilities manually when any component underlying a pipeline was initialized with "balanced" device_map.

I agree with that. But in the previous version, _optionally_disable_offloading() returns is_sequential_cpu_offload=True because of the AlignDevicesHook when a device_map is used, which offloads the model to the CPU.

But I need a more flexible device_map, so I'm still adding hooks manually. Thank you for your patience! I still think we shouldn't set is_sequential_cpu_offload=True just because a model has an AlignDevicesHook. Maybe something like this?

is_sequential_cpu_offload = component.device.type == "cpu" and (
    isinstance(component._hf_hook, AlignDevicesHook)
    or (
        hasattr(component._hf_hook, "hooks")
        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
    )
)

sayakpaul commented 4 months ago

Feel free to open a PR and we can take it from there :-)

zhangvia commented 4 months ago

I've just created PR #8750 for this.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.