pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Add support for float8 dtypes for the MPS backend #132624

Open lisiyizu opened 1 month ago

lisiyizu commented 1 month ago

šŸ› Describe the bug


model_type FLOW
clip missing: ['text_projection.weight']
Requested to load FluxClipModel_
Loading 1 new model
Requested to load Flux
Loading 1 new model
  0%|                                                                                                             | 0/4 [00:03<?, ?it/s]
!!! Exception during processing!!! Trying to convert Float8_e5m2 to the MPS backend but it does not have support for that dtype.
Traceback (most recent call last):
  File "/Users/kummy/Documents/work/ComfyUI/execution.py", line 152, in recursive_execute
    output_data, output_ui = get_output_data(obj, input_data_all)
  File "/Users/kummy/Documents/work/ComfyUI/execution.py", line 82, in get_output_data
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
  File "/Users/kummy/Documents/work/ComfyUI/execution.py", line 75, in map_node_over_list
    results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
  File "/Users/kummy/Documents/work/ComfyUI/comfy_extras/nodes_custom_sampler.py", line 612, in sample
    samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 716, in sample
    output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 695, in inner_sample
    samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 600, in sample
    samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
  File "/Users/kummy/miniconda3/envs/comflowy/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/k_diffusion/sampling.py", line 143, in sample_euler
    denoised = model(x, sigma_hat * s_in, **extra_args)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 299, in __call__
    out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 682, in __call__
    return self.predict_noise(*args, **kwargs)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 685, in predict_noise
    return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 279, in sampling_function
    out = calc_cond_batch(model, conds, x, timestep, model_options)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/samplers.py", line 228, in calc_cond_batch
    output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/model_base.py", line 122, in apply_model
    model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
  File "/Users/kummy/miniconda3/envs/comflowy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/kummy/miniconda3/envs/comflowy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/ldm/flux/model.py", line 143, in forward
    out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/ldm/flux/model.py", line 101, in forward_orig
    img = self.img_in(img)
  File "/Users/kummy/miniconda3/envs/comflowy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/kummy/miniconda3/envs/comflowy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/ops.py", line 63, in forward
    return self.forward_comfy_cast_weights(*args, **kwargs)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/ops.py", line 58, in forward_comfy_cast_weights
    weight, bias = cast_bias_weight(self, input)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/ops.py", line 39, in cast_bias_weight
    bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
  File "/Users/kummy/Documents/work/ComfyUI/comfy/ops.py", line 24, in cast_to
    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
TypeError: Trying to convert Float8_e5m2 to the MPS backend but it does not have support for that dtype.

Prompt executed in 148.34 seconds

Versions

Collecting environment information...
PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.8 (main, Nov 24 2022, 08:08:27) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M2 Pro

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxruntime==1.15.1
[pip3] open_clip_torch==2.26.1
[pip3] pytorch-lightning==2.3.3
[pip3] torch==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchmetrics==1.4.0.post0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.19.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h8e150cf_43560
[conda] mkl-service 2.4.0 py312h6c40b1e_1
[conda] mkl_fft 1.3.8 py312h6c40b1e_0
[conda] mkl_random 1.2.4 py312ha357a0b_0
[conda] numpy 1.26.4 py312hac873b0_0
[conda] numpy-base 1.26.4 py312h6f81483_0
[conda] pytorch 2.2.2 py3.12_0 pytorch
[conda] torchaudio 2.2.2 py312_cpu pytorch
[conda] torchvision 0.17.2 py312_cpu pytorch

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @yanbing-j @vkuzo @kadeng @penguinwu

Vargol commented 1 month ago

If you're trying to get Flux working on MPS, then besides getting fp8 support you'll also need to figure out why Flux is broken (noisy images) on PyTorch 2.4 but works with PyTorch 2.3.1.

vkuzo commented 1 month ago

Having float8 dtypes defined on MPS would be nice, we'd welcome community contributions on this!
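For anyone picking up this invitation, one property of the format may be useful. float8_e5m2 has the same sign bit and 5-bit exponent (bias 15) as IEEE float16 and keeps only the top 2 mantissa bits. As a result, every e5m2 bit pattern is exactly the high byte of a float16 with the same value, including infinities and NaNs. The pure-Python sketch below decodes raw e5m2 bytes this way. It illustrates one possible emulation strategy for a backend without native float8 (store bytes, widen to float16 for compute). It does not show how PyTorch implements float8 on any device. `e5m2_to_float` and `decode_e5m2` are hypothetical names.

```python
import struct


def e5m2_to_float(byte: int) -> float:
    """Decode one float8_e5m2 bit pattern (0..255) to a Python float."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected an 8-bit value")
    # e5m2 = 1 sign bit, 5 exponent bits, 2 mantissa bits: the top byte of an
    # IEEE float16. Append eight zero mantissa bits and reinterpret as float16.
    (value,) = struct.unpack("<e", struct.pack("<H", byte << 8))
    return value


def decode_e5m2(data: bytes) -> list[float]:
    """Decode a buffer of raw e5m2 bytes, one value per byte."""
    return [e5m2_to_float(b) for b in data]
```

Because the widening is a lossless shift, a backend could in principle keep fp8 bytes in memory and upcast to float16 just before each kernel. The trade-off is float16-sized temporaries during compute.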