ljleb / sd-webui-freeu

a1111 implementation of https://github.com/ChenyangSi/FreeU
MIT License

Apple silicon: MPS backend does not have support for that dtype #12

Closed knox closed 1 year ago

knox commented 1 year ago

On Apple silicon it fails with:

      File "/Users/knox/stable-diffusion-webui/extensions/sd-webui-freeu/lib_free_u/unet.py", line 124, in filter_skip
        x_freq = torch.fft.fftn(x.to(dtype=torch.float32), dim=(-2, -1))
    TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.
ljleb commented 1 year ago

Please share the whole stack trace. I need to know whether this happens on the first or the second decoder block.

knox commented 1 year ago

here you go

    Traceback (most recent call last):
      File "/Users/knox/stable-diffusion-webui/modules/call_queue.py", line 57, in f
        res = list(func(*args, **kwargs))
      File "/Users/knox/stable-diffusion-webui/modules/call_queue.py", line 36, in f
        res = func(*args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/txt2img.py", line 55, in txt2img
        processed = processing.process_images(p)
      File "/Users/knox/stable-diffusion-webui/modules/processing.py", line 732, in process_images
        res = process_images_inner(p)
      File "/Users/knox/stable-diffusion-webui/extensions/sd-webui-controlnet/scripts/batch_hijack.py", line 42, in processing_process_images_hijack
        return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/processing.py", line 867, in process_images_inner
        samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
      File "/Users/knox/stable-diffusion-webui/modules/processing.py", line 1140, in sample
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
      File "/Users/knox/stable-diffusion-webui/modules/sd_samplers_timesteps.py", line 158, in sample
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
      File "/Users/knox/stable-diffusion-webui/modules/sd_samplers_common.py", line 261, in launch_sampling
        return func()
      File "/Users/knox/stable-diffusion-webui/modules/sd_samplers_timesteps.py", line 158, in <lambda>
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
      File "/Users/knox/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/sd_samplers_timesteps_impl.py", line 24, in ddim
        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
      File "/Users/knox/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/sd_samplers_cfg_denoiser.py", line 188, in forward
        x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
      File "/Users/knox/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/sd_samplers_timesteps.py", line 30, in forward
        return self.inner_model.apply_model(input, timesteps, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/sd_hijack_utils.py", line 17, in <lambda>
        setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
      File "/Users/knox/stable-diffusion-webui/modules/sd_hijack_utils.py", line 26, in __call__
        return self.__sub_func(self.__orig_func, *args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/sd_hijack_unet.py", line 48, in apply_model
        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
      File "/Users/knox/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py", line 858, in apply_model
        x_recon = self.model(x_noisy, t, **cond)
      File "/Users/knox/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/repositories/stable-diffusion-stability-ai/ldm/models/diffusion/ddpm.py", line 1335, in forward
        out = self.diffusion_model(x, t, context=cc)
      File "/Users/knox/stable-diffusion-webui/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
        return forward_call(*args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/modules/sd_unet.py", line 91, in UNetModel_forward
        return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
      File "/Users/knox/stable-diffusion-webui/extensions/sd-webui-freeu/lib_free_u/unet.py", line 48, in forward
        h = free_u_cat(h, hs.pop())
      File "/Users/knox/stable-diffusion-webui/extensions/sd-webui-freeu/lib_free_u/unet.py", line 117, in free_u_cat
        h_skip = filter_skip(h_skip, threshold=1, scale=global_state.skip_factors[index])
      File "/Users/knox/stable-diffusion-webui/extensions/sd-webui-freeu/lib_free_u/unet.py", line 124, in filter_skip
        x_freq = torch.fft.fftn(x.to(dtype=torch.float32), dim=(-2, -1))
    TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.
ljleb commented 1 year ago

Ah, actually that does not give me the info I need. I'll try some fixes and we'll see if they solve it for you.

ljleb commented 1 year ago

I pushed an update, let me know if the problem persists.

knox commented 1 year ago

thanks for your efforts. with 1151feb the error has moved to:

      File "/Users/knox/stable-diffusion-webui/extensions/sd-webui-freeu/lib_free_u/unet.py", line 50, in filter_skip
        x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
    TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.

i was trying to dig into this myself, but unfortunately i don't have much of a clue about python, torch and things.

ljleb commented 1 year ago

Thanks for the update. A workaround is to do the FFT on the CPU, although that is pretty slow. Apparently AMD + DirectML needs this as well.
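A minimal sketch of that CPU workaround, for illustration only. It assumes the filter scales a small centered block of the shifted spectrum, which is what the `threshold` and `scale` arguments in the traceback suggest; `filter_skip_on_cpu` is a hypothetical name and this is not the extension's actual code.

```python
import torch


def filter_skip_on_cpu(x: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    """Scale the low-frequency band of x, doing all complex math on the CPU."""
    original_device, original_dtype = x.device, x.dtype

    # The FFT output is complex, which is the unsupported dtype,
    # so move the real input to the CPU before transforming it.
    x_cpu = x.to(device="cpu", dtype=torch.float32)

    x_freq = torch.fft.fftn(x_cpu, dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))  # DC term moves to the center

    height, width = x_freq.shape[-2:]
    center_row, center_col = height // 2, width // 2

    # Real-valued mask: `scale` on the centered low-frequency block, 1 elsewhere.
    mask = torch.ones(x_freq.shape, dtype=torch.float32)
    mask[
        ...,
        center_row - threshold:center_row + threshold,
        center_col - threshold:center_col + threshold,
    ] = scale

    x_freq = torch.fft.ifftshift(x_freq * mask, dim=(-2, -1))
    x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real

    # Only the real result goes back to the original device and dtype.
    return x_filtered.to(device=original_device, dtype=original_dtype)
```

The cost is two device transfers per call on top of a CPU FFT, which is where the slowdown comes from.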

knox commented 1 year ago

i've spent some more time trying to understand the actual problem, and here is some documentation of it. most of this you probably know better than me, but some pieces may still be helpful.

❯ sysctl -a | grep machdep.cpu.brand_string
machdep.cpu.brand_string: Apple M1 Pro
❯ python -m pip list | egrep "^torch\ "
torch                     2.0.1

Script from here https://github.com/pytorch/pytorch/issues/78168#issuecomment-1137686403 gives the expected results:

Valid Types: [torch.float32, torch.float32, torch.float16, torch.float16, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int32, torch.int64, torch.int64, torch.bool]
Invalid Types: [torch.float64, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.complex128, torch.quint8, torch.qint8, torch.quint4x2]
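For readers who want to reproduce this kind of check, here is a rough sketch of a dtype probe. It is not the linked script; `split_supported_dtypes` and `pick_device` are hypothetical helpers, and the candidate list is only a small sample of dtypes.

```python
import torch

# A small sample of dtypes to probe; extend as needed.
CANDIDATE_DTYPES = [
    torch.float32,
    torch.float16,
    torch.int64,
    torch.bool,
    torch.complex64,
]


def split_supported_dtypes(device, dtypes):
    """Try to allocate a one-element tensor of each dtype on the device."""
    valid, invalid = [], []
    for dtype in dtypes:
        try:
            torch.zeros(1, dtype=dtype, device=device)
        except (TypeError, RuntimeError):
            # Backends report an unsupported dtype by raising on allocation.
            invalid.append(dtype)
        else:
            valid.append(dtype)
    return valid, invalid


def pick_device():
    """Use MPS when this torch build exposes it and it is available, else the CPU."""
    mps = getattr(torch.backends, "mps", None)
    return "mps" if mps is not None and mps.is_available() else "cpu"


if __name__ == "__main__":
    valid, invalid = split_supported_dtypes(pick_device(), CANDIDATE_DTYPES)
    print("Valid Types:", valid)
    print("Invalid Types:", invalid)
```

Which dtypes end up in each list depends on the torch version and the backend, so the output will not necessarily match the lists above.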
ljleb commented 1 year ago

I pushed a dirty fix that does the fft on the cpu for mac. Let me know if that fixes it. It should be slower, but it should work.

Valid Types: [torch.float32, torch.float32, torch.float16, torch.float16, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int32, torch.int64, torch.int64, torch.bool]
Invalid Types: [torch.float64, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.complex128, torch.quint8, torch.qint8, torch.quint4x2]

Yeah as expected, complex numbers are not available on the mps device. See https://github.com/pytorch/pytorch/issues/89657

ljleb commented 1 year ago

I think a better solution would be to provide an alternative filter option that does not need complex numbers.
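One possible shape for such a filter, as a sketch only: split the feature map into a blurred low-frequency part and the residual, then rescale just the blurred part. This stays entirely in real arithmetic, so no complex dtype is involved. It is an assumption about what an alternative could look like, not what the extension does, and a box blur has a different frequency response than the Fourier mask, so results would not match exactly. `box_filter_skip` and `kernel_size` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def box_filter_skip(x: torch.Tensor, scale: float, kernel_size: int = 3) -> torch.Tensor:
    """Scale the low-frequency content of x using a box blur instead of an FFT."""
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd so the output keeps the input size")

    # Low-pass estimate: local average with stride 1 and "same" padding.
    # count_include_pad=False keeps border averages unbiased by the zero padding.
    low = F.avg_pool2d(
        x,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
        count_include_pad=False,
    )

    # Whatever the blur removed is treated as the high-frequency part.
    high = x - low

    # Rescale only the low-frequency part and recombine.
    return scale * low + high
```

With `scale=1.0` the input comes back unchanged, and a larger `kernel_size` pushes more of the signal into the part that gets rescaled.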

knox commented 1 year ago

with 1062ee6 it no longer crashes 🎉

duration for a simple gen goes up from 17.3 to 19.5

a woman in the garden
Negative prompt: disfigured, ugly
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3328468048, Size: 512x512, Model hash: 6cb32f354e, Model: aniverse_V13Pruned, VAE hash: f921fb3f29, VAE: vae-ft-mse-840000-ema-pruned.safetensors, Clip skip: 2, FreeU Stages: "[{\"backbone_factor\": 1.2, \"skip_factor\": 0.9}, {\"backbone_factor\": 1.4, \"skip_factor\": 0.2}]"