instantX-research / InstantID

InstantID: Zero-shot Identity-Preserving Generation in Seconds 🔥
https://instantid.github.io/
Apache License 2.0
11.12k stars, 807 forks

`pipe.enable_model_cpu_offload()` fails #121

Closed: wjkoh closed this issue 9 months ago

wjkoh commented 9 months ago

I get RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' after modifying infer.py as follows:

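import torch
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline

# base_model_path, controlnet and face_adapter are defined earlier in infer.py
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(base_model_path, torch_dtype=torch.float16, controlnet=controlnet)
pipe.load_ip_adapter_instantid(face_adapter)
pipe.enable_model_cpu_offload()  # added line; the later pipe(...) call then fails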

The call stack is as follows:

Traceback (most recent call last):
  File "/home/wjkoh/bake-diffusers/InstantID/infer.py", line 92, in <module>
    image = pipe(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/wjkoh/bake-diffusers/InstantID/pipeline_stable_diffusion_xl_instantid.py", line 523, in __call__
    prompt_image_emb = self._encode_prompt_image_emb(image_embeds, 
  File "/home/wjkoh/bake-diffusers/InstantID/pipeline_stable_diffusion_xl_instantid.py", line 236, in _encode_prompt_image_emb
    prompt_image_emb = self.image_proj_model(prompt_image_emb)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/wjkoh/bake-diffusers/InstantID/ip_adapter/resampler.py", line 114, in forward
    x = self.proj_in(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

ResearcherXman commented 9 months ago

You should use torch.float32 instead.
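Switching to float32 avoids the error because PyTorch ships CPU matrix-multiply kernels for float32, while some PyTorch versions have no CPU addmm kernel for float16 (Half), which is exactly what the traceback reports. The sketch below is a hypothetical check, not part of InstantID, that tries a small nn.Linear forward pass on the CPU in a given dtype so you can see how your installed PyTorch behaves.

```python
import torch
import torch.nn as nn


def cpu_linear_supports(dtype: torch.dtype) -> bool:
    """Return True if an nn.Linear forward pass runs on the CPU in `dtype`."""
    layer = nn.Linear(8, 4).to(dtype=dtype)  # CPU weights in the requested dtype
    x = torch.randn(2, 8).to(dtype=dtype)    # create in float32, then cast to `dtype`
    try:
        layer(x)  # with a bias and 2-D input, this goes through addmm
        return True
    except RuntimeError:  # e.g. "addmm_impl_cpu_" not implemented for 'Half'
        return False


if __name__ == "__main__":
    for dt in (torch.float32, torch.float16):
        print(dt, cpu_linear_supports(dt))
```

On builds where the float16 check prints False, any module left on the CPU in float16 will fail the same way as image_proj_model does above.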

wjkoh commented 9 months ago

Thanks for the feedback. It does work with torch.float32! However, this somewhat defeats the purpose of model CPU offload, which is to reduce VRAM usage: enabling it forces us to switch from float16 to float32, which doubles the memory each model's weights occupy on the GPU.
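The cost of that switch is easy to quantify for the weights alone: float16 stores 2 bytes per parameter and float32 stores 4. The sketch below computes weight memory for a given parameter count; the one-billion count in the usage is a hypothetical round number for illustration, not the size of any InstantID or SDXL component, and activations are ignored.

```python
# Bytes needed to store one element of each floating-point dtype.
BYTES_PER_ELEMENT = {"float16": 2, "float32": 4}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Memory needed just to hold `num_params` weights of `dtype`, in GiB."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 1024**3


if __name__ == "__main__":
    n = 1_000_000_000  # hypothetical parameter count, for illustration only
    for dt in ("float16", "float32"):
        print(f"{dt}: {weight_memory_gib(n, dt):.2f} GiB")
    # float16: 1.86 GiB
    # float32: 3.73 GiB
```

With model CPU offload, roughly one model's weights sit on the GPU at a time, so this doubling applies to whichever model is currently active.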