Open nancygd opened 3 months ago
Could you please post full output from ComfyUI?
I'm having the same issue.
!!! Exception during processing !!! The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1 Traceback (most recent call last): File "/ComfyUI/execution.py", line 313, in execute output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/execution.py", line 188, in get_output_data return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/execution.py", line 165, in map_node_over_list process_inputs(input_dict, i) File "/ComfyUI/execution.py", line 154, in process_inputs results.append(getattr(obj, func)(inputs)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 2225, in sample_adv return super().sample(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 732, in sample samples, images, gifs, preview = process_latent_image(model, seed, steps, cfg, sampler_name, scheduler, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/efficiency-nodes-comfyui/efficiency_nodes.py", line 554, in process_latent_image samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/nodes.py", line 1452, in sample return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/nodes.py", line 1385, in common_ksampler samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 22, in informative_sample raise e File "/ComfyUI/custom_nodes/ComfyUI-Impact-Pack/modules/impact/sample_error_enhancer.py", line 9, in informative_sample return original_sample(args, kwargs) # This code helps interpret error messages that occur within exceptions but does not have any impact on other operations. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/sample.py", line 43, in sample samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 829, in sample return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/model_patch.py", line 120, in modified_sample return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 716, in sample output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 695, in inner_sample samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 600, in sample samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, self.extra_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context return func(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/k_diffusion/sampling.py", line 635, in sample_dpmpp_2m_sde denoised = model(x, sigmas[i] * s_in, extra_args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 299, in call out = self.inner_model(x, sigma, model_options=model_options, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 682, in call return self.predict_noise(*args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 685, in predict_noise return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 279, in sampling_function out = calc_cond_batch(model, conds, x, timestep, model_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/samplers.py", line 226, in calc_cond_batch output = model_options['model_function_wrapper'](model.apply_model, {"input": inputx, "timestep": timestep, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/model_patch.py", line 52, in brushnet_model_function_wrapper return apply_model_method(x, timestep, options_dict['c']) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/model_base.py", line 145, in apply_model model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, *extra_conds).float() ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(*args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 852, in forward h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/comfy/ldm/modules/diffusionmodules/openaimodel.py", line 44, in forward_timestep_embed x = layer(x, context, transformer_options) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl return self._call_impl(*args, *kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl return forward_call(args, kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/ComfyUI/custom_nodes/ComfyUI-BrushNet/brushnet_nodes.py", line 1061, in forward_patched_by_brushnet h += to_add.to(h.dtype).to(h.device) RuntimeError: The size of tensor a (640) must match the size of tensor b (320) at non-singleton dimension 1
I don't know what changed, but it works for me now
I can't reproduce the error as well. May be some commits of ComfyUI are the reason.
I figured it out, I had my ComfyUI launcher script running the argument "--fp8_e4m3fn-unet" for flux.
I can't reproduce the error as well. May be some commits of ComfyUI are the reason.
I don't know what happen at here, but my friend sometime had meet the problem, he can fix the problem when he change checkpoint model. but i can't use the method to deal ,so i don't know what happen ,thank you
What checkpoint do you use? It should be float16
, bfloat16
, float32
, or float64
. Also check ComfyUI startup options.
I have this exact same error! as long as I removed --fast command in launch argument, this error is gone.... But i wish I can have both.. --fast is incredibly power speed up 40 series flux generation by 40% https://github.com/comfyanonymous/ComfyUI/commit/904bf58e7d27eb254d20879e306042653debc4b3
After upgrading to the latest version of Comfyui on October 11th, there is an error message. Returning to the version on October 9th is normal. How to solve this problem?
I also encountered
I upgraded CUDA to version 12.4, and now using the latest version of Comfyui is working properly. Everyone can give it a try.
This issue after upgrade Comyui(https://github.com/comfyanonymous/ComfyUI/commit/e38c94228bce913c1a88f6776f6a21bd64926aec#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). After that, loading the Flux model no longer requires adding the --fast
parameter in the command line; you can directly select fp8_e4m3fn_fast
. Removing the --fast
parameter resolves this problem.
如果使用绘世可以把这个勾去掉
This issue after upgrade Comyui(comfyanonymous/ComfyUI@e38c942#diff-83920b72a497ff05a33ecf5ac3d19df7911f228f9921fa21e7b64c3b24781fafR101). After that, loading the Flux model no longer requires adding the
--fast
parameter in the command line; you can directly selectfp8_e4m3fn_fast
. Removing the--fast
parameter resolves this problem.如果使用绘世可以把这个勾去掉
This seems to solve the issue but the "--fast" still works faster than the weight dtype "fp8_e4m3fn_fast" in the Load Diffusion Model node. Maybe @nullquant can work something out so we can have his argument --fast and BrushNet at the same time. Appreciate that! Thx for the hardwork
""" This file is part of ComfyUI. Copyright (C) 2024 Stability AI
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch import comfy.model_management from comfy.cli_args import args import comfy.float
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: if dtype is None: dtype = input.dtype if bias_dtype is None: bias_dtype = dtype if device is None: device = input.device
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = s.bias_function is not None
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
has_function = s.weight_function is not None
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
class CastWeightBiasOp: comfy_cast_weights = False weight_function = None bias_function = None
class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input, output_size=None):
num_spatial_dims = 2
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input, output_size=None):
num_spatial_dims = 1
output_padding = self._output_padding(
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input, out_dtype=None):
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
kwargs.pop("out_dtype")
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): comfy_cast_weights = True
class Conv1d(disable_weight_init.Conv1d):
comfy_cast_weights = True
class Conv2d(disable_weight_init.Conv2d):
comfy_cast_weights = True
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True
class LayerNorm(disable_weight_init.LayerNorm):
comfy_cast_weights = True
class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
comfy_cast_weights = True
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
def fp8_linear(self, input): dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = input.reshape(-1, input.shape[2]).to(dtype)
else:
scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input.shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast): class Linear(manual_cast.Linear): def reset_parameters(self): self.scale_weight = None self.scale_input = None return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): class scaled_fp8_op(manual_cast): class Linear(manual_cast.Linear): def init(self, *args, *kwargs): if override_dtype is not None: kwargs['dtype'] = override_dtype super().init(args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
if weight.numel() < input.numel(): #TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
return fp8_ops
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
return manual_cast
hello, when i run until KSampler, there is a error, do you know how to deal with it? thank you!