nullquant / ComfyUI-BrushNet

ComfyUI BrushNet nodes
Apache License 2.0
642 stars 24 forks source link

RuntimeError: The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1 #136

Open nancygd opened 3 months ago

nancygd commented 3 months ago

Hello! When I run the workflow up to the KSampler node, I get an error. Do you know how to fix it? Thank you!

nullquant commented 3 months ago

Could you please post full output from ComfyUI?

Thater commented 3 months ago

I'm having the same issue.

!!! Exception during processing !!! The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1 Traceback (most recent call last): File "/ComfyUI/execution.py", line 313, in execute output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/execution.py", line 188, in get_output_data return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/execution.py", line 165, in map_node_over_list process_inputs(input_dict, i) File "/ComfyUI/execution.py", line 154, in process_inputs results.append(getattr(obj, func)(inputs)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 2225, in sample_adv return super().sample(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 732, in sample samples, images, gifs, preview = process_latent_image(model, seed, steps, cfg, sampler_name, scheduler, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 554, in process_latent_image samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/nodes.py", line 1452, in sample return common_ksampler(model, noise_seed, 
steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/nodes.py", line 1385, in common_ksampler samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 22, in informative_sample raise e File "/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 9, in informative_sample return original_sample(args, kwargs) # This code helps interpret error messages that occur within exceptions but does not have any impact on other operations. 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/sample.py", line 43, in sample samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 829, in sample return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/model_patch.py", line 120, in modified_sample return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 716, in sample output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 695, in inner_sample samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 600, in 
sample samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, self.extra_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context return func(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/k_diffusion/sampling.py", line 635, in sample_dpmpp_2m_sde denoised = model(x, sigmas[i] * s_in, extra_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 299, in call out = self.inner_model(x, sigma, model_options=model_options, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 682, in call return self.predict_noise(*args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 685, in predict_noise return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 279, in sampling_function out = calc_cond_batch(model, conds, x, timestep, model_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 226, in calc_cond_batch output = model_options['model_function_wrapper'](model.apply_model, {"input": inputx, "timestep": timestep, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/model_patch.py", line 52, in 
brushnet_model_function_wrapper return apply_model_method(x, timestep, options_dict['c']) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/model_base.py", line 145, in apply_model model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, *extra_conds).float() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 852, in forward h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 44, in forward_timestep_embed x = layer(x, context, transformer_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, *kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/brushnet_nodes.py", line 1061, in forward_patched_by_brushnet h += to_add.to(h.dtype).to(h.device) RuntimeError: The size of tensor a (640) must match the size of 
tensor b (320) at non-singleton dimension 1

Thater commented 2 months ago

I don't know what changed, but it works for me now.

nullquant commented 2 months ago

I can't reproduce the error either. Maybe some ComfyUI commits are the cause.

Thater commented 2 months ago

I figured it out: my ComfyUI launcher script was passing the "--fp8_e4m3fn-unet" argument, which I use for Flux.

nancygd commented 2 months ago

> I can't reproduce the error as well. May be some commits of ComfyUI are the reason.

I don't know what is happening here. My friend sometimes runs into the same problem and can fix it by switching to a different checkpoint model, but that doesn't work for me, so I don't know what is going on. Thank you!

nullquant commented 2 months ago

Which checkpoint do you use? Its weights should be float16, bfloat16, float32, or float64. Also check your ComfyUI startup options.

LiJT commented 1 month ago

I have this exact same error! As soon as I removed the --fast launch argument, the error went away. But I wish I could have both: --fast speeds up Flux generation on 40-series GPUs by an incredible 40%. https://github.com/comfyanonymous/ComfyUI/commit/904bf58e7d27eb254d20879e306042653debc4b3

cjc999 commented 1 month ago

After upgrading to the latest version of ComfyUI on October 11th, I get this error. Reverting to the October 9th version works normally. How can I solve this problem?

zhiyulee3 commented 1 month ago

I also encountered this error.

cjc999 commented 1 month ago

I upgraded CUDA to version 12.4, and now the latest version of ComfyUI works properly. Everyone can give it a try.

Anson2048 commented 1 month ago

This issue appeared after upgrading ComfyUI (https://github.com/comfyanonymous/ComfyUI/commit/e38c94228bce913c1a88f6776f6a21bd64926aec#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). Since that commit, loading the Flux model no longer requires adding the --fast parameter on the command line; you can directly select the fp8_e4m3fn_fast weight dtype instead. Removing the --fast parameter resolves this problem.

如果使用绘世可以把这个勾去掉 (If you use the 绘世 launcher, you can uncheck this option.) [image]

Orenji-Tangerine commented 4 days ago

> This issue after upgrade Comyui(comfyanonymous/ComfyUI@e38c942#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). After that, loading the Flux model no longer requires adding the --fast parameter in the command line; you can directly select fp8_e4m3fn_fast. Removing the --fastparameter resolves this problem.
>
> 如果使用绘世可以把这个勾去掉 image

This seems to solve the issue, but "--fast" still runs faster than selecting the "fp8_e4m3fn_fast" weight dtype in the Load Diffusion Model node. Maybe @nullquant can work something out so that we can use the --fast argument and BrushNet at the same time. I'd appreciate that! Thanks for the hard work.

kanxun88 commented 3 days ago

""" This file is part of ComfyUI. Copyright (C) 2024 Stability AI

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""

import torch

import comfy.float
import comfy.model_management
from comfy.cli_args import args

# Module-level re-export of comfy.model_management.cast_to; per the TODO it is
# kept only for existing references and can be dropped once none remain.
cast_to = comfy.model_management.cast_to #TODO: remove once no more references

def cast_to_input(weight, input, non_blocking=False, copy=True):
    """Cast ``weight`` to the dtype and device of the reference tensor ``input``.

    Thin convenience wrapper around ``comfy.model_management.cast_to`` that
    derives the target dtype/device from ``input`` instead of taking them
    explicitly. ``non_blocking`` and ``copy`` are forwarded unchanged.
    """
    target_dtype = input.dtype
    target_device = input.device
    return comfy.model_management.cast_to(
        weight,
        target_dtype,
        target_device,
        non_blocking=non_blocking,
        copy=copy,
    )

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    """Return ``(weight, bias)`` of module ``s`` cast for the current call.

    When ``input`` is given it supplies defaults: ``dtype`` and ``device``
    fall back to the input's dtype/device, and ``bias_dtype`` falls back to
    ``dtype``. Explicitly passed arguments always take precedence.

    If ``s.weight_function`` / ``s.bias_function`` is set, the cast is
    requested with ``copy=True`` and the result is then passed through that
    function.

    Returns:
        tuple: ``(weight, bias)``; ``bias`` is ``None`` when ``s.bias`` is None.
    """
    if input is not None:
        dtype = input.dtype if dtype is None else dtype
        bias_dtype = dtype if bias_dtype is None else bias_dtype
        device = input.device if device is None else device

    cast = comfy.model_management.cast_to
    non_blocking = comfy.model_management.device_supports_non_blocking(device)

    bias = None
    if s.bias is not None:
        bias_fn = s.bias_function
        bias = cast(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_fn is not None)
        if bias_fn is not None:
            bias = bias_fn(bias)

    weight_fn = s.weight_function
    weight = cast(s.weight, dtype, device, non_blocking=non_blocking, copy=weight_fn is not None)
    if weight_fn is not None:
        weight = weight_fn(weight)

    return weight, bias

class CastWeightBiasOp:
    """Mixin carrying the class-level switches used by the casting layers.

    Attributes (class-level defaults):
        comfy_cast_weights: opt-in flag; layers check it in ``forward`` to
            decide whether to cast parameters on every call.
        weight_function / bias_function: optional post-cast hooks applied by
            ``cast_bias_weight``; ``None`` means the cast tensor is used as-is.
    """

    # Per-call casting is off unless a subclass turns it on.
    comfy_cast_weights = False

    # Optional callables applied to the cast weight / bias.
    weight_function = None
    bias_function = None

class disable_weight_init:
    """Namespace of ``torch.nn`` layer subclasses that skip default initialisation.

    Every layer overrides ``reset_parameters`` so PyTorch's default parameter
    initialisation is not run (``Embedding`` additionally sets ``bias = None``).
    Each layer's ``forward`` dispatches on the ``comfy_cast_weights`` class
    flag inherited from ``CastWeightBiasOp``:

    * True  -> ``forward_comfy_cast_weights``, which casts weight/bias for the
      call via ``cast_bias_weight`` and runs the functional op;
    * False -> the stock ``torch.nn`` forward.
    """

    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            # A LayerNorm without affine parameters has weight (and bias) None;
            # there is nothing to cast in that case.
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            # torch.nn.Embedding has no bias, but cast_bias_weight reads s.bias,
            # so define it explicitly as None.
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            # fp16/bf16 weights are not cast to out_dtype (dtype=None is passed
            # to the cast); the lookup result is converted to out_dtype below.
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                # The stock torch.nn.Embedding.forward does not accept out_dtype.
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Build a ``dims``-dimensional convolution from this namespace's ops.

        Uses ``s.Conv1d`` / ``s.Conv2d`` / ``s.Conv3d``, so subclasses such as
        casting variants get their own layer types. Remaining arguments are
        forwarded to the layer constructor.

        Raises:
            ValueError: if ``dims`` is not 1, 2 or 3.
        """
        if dims == 1:
            return s.Conv1d(*args, **kwargs)
        elif dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")

class manual_cast(disable_weight_init):
    """``disable_weight_init`` variant with per-call parameter casting enabled.

    Every nested layer only flips ``comfy_cast_weights`` to True, so its
    ``forward`` routes through ``forward_comfy_cast_weights`` and casts the
    weight/bias to match the input on each call. The inherited ``conv_nd``
    classmethod resolves ``s.ConvNd`` against this class and therefore builds
    these casting variants.
    """

    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True

def fp8_linear(self, input): dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None

tensor_2d = False
if len(input.shape) == 2:
    tensor_2d = True
    input = input.unsqueeze(1)

if len(input.shape) == 3:
    w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
    w = w.t()

    scale_weight = self.scale_weight
    scale_input = self.scale_input
    if scale_weight is None:
        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
    else:
        scale_weight = scale_weight.to(input.device)

    if scale_input is None:
        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
        inn = input.reshape(-1, input.shape[2]).to(dtype)
    else:
        scale_input = scale_input.to(input.device)
        inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)

    if bias is not None:
        o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
    else:
        o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)

    if isinstance(o, tuple):
        o = o[0]

    if tensor_2d:
        return o.reshape(input.shape[0], -1)

    return o.reshape((-1, input.shape[1], self.weight.shape[0]))

return None

class fp8_ops(manual_cast):
    """``manual_cast`` variant whose Linear first tries the fp8 ``fp8_linear`` path."""

    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            # No per-tensor scales; fp8_linear substitutes a scale of 1 for None.
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            fast_out = fp8_linear(self, input)
            if fast_out is None:
                # fp8 path not applicable: plain cast-and-matmul.
                weight, bias = cast_bias_weight(self, input)
                return torch.nn.functional.linear(input, weight, bias)
            return fast_out

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    """Build an ops namespace whose Linear keeps a per-tensor ``scale_weight``.

    Args:
        fp8_matrix_mult: if True, ``forward`` first tries ``fp8_linear``
            (``torch._scaled_mm``) and falls back to the generic path when it
            returns None.
        scale_input: if False, every Linear gets ``scale_input = None``.
            Otherwise a ``scale_input`` parameter initialised to 1 is created
            when not already present.
        override_dtype: if not None, forced as the ``dtype`` constructor
            argument of every Linear.

    Returns:
        type: a ``manual_cast`` subclass with the scaled ``Linear``.
    """
    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                # Called from torch.nn.Linear.__init__. The scale parameters
                # are only created if absent, so existing values are kept.
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

                if not scale_input:
                    self.scale_input = None

                if not hasattr(self, 'scale_input'):
                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)
                scale = self.scale_weight.to(device=weight.device, dtype=weight.dtype)

                # Apply the scale to whichever operand has fewer elements; the
                # result is the same up to floating-point rounding.
                if weight.numel() < input.numel(): #TODO: optimize
                    return torch.nn.functional.linear(input, weight * scale, bias)
                else:
                    return torch.nn.functional.linear(input * scale, weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                """Return ``weight`` multiplied by ``scale_weight`` (in place if requested)."""
                scale = self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                if inplace:
                    weight *= scale
                    return weight
                else:
                    return weight * scale

            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                """Divide ``weight`` by ``scale_weight`` and store it, stochastically rounded to the stored dtype."""
                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
                if inplace_update:
                    self.weight.data.copy_(weight)
                else:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)

    return scaled_fp8_op

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
    """Choose the layer-ops namespace used to build a model.

    Checked in order:

    1. ``scaled_fp8`` is not None -> ``scaled_fp8_ops`` with
       ``scale_input=True``. fp8 matmul is used only if
       ``supports_fp8_compute(load_device)`` is true.
    2. fp8 compute is supported, fast fp8 is requested (``fp8_optimizations``
       or ``args.fast`` from ``comfy.cli_args``), and ``disable_fast_fp8`` is
       not set -> ``fp8_ops``.
    3. ``compute_dtype`` is None or equals ``weight_dtype`` ->
       ``disable_weight_init`` (no per-call casting).
    4. Otherwise -> ``manual_cast``.
    """
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)

    wants_fast_fp8 = (fp8_optimizations or args.fast) and not disable_fast_fp8
    if fp8_compute and wants_fast_fp8:
        return fp8_ops

    no_cast_needed = compute_dtype is None or weight_dtype == compute_dtype
    return disable_weight_init if no_cast_needed else manual_cast