shiimizu / ComfyUI-TiledDiffusion

Tiled Diffusion, MultiDiffusion, Mixture of Diffusers, and optimized VAE

Tiled VAE RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! #18

Open yeungmozhu opened 6 months ago

yeungmozhu commented 6 months ago

I use IP-Adapter in my workflow. When execution reaches the Tiled VAE node, it raises: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! When I replace it with the original VAE Decode node, the problem goes away.
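(For context, this RuntimeError occurs whenever a PyTorch op receives tensors living on different devices; the usual fix is to move everything onto a common device before combining. A minimal CPU-runnable sketch of that pattern, with an illustrative function name, not code from this repo:)

```python
import torch

def scale(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Move the parameter onto the input's device before combining them.
    # If x were on cuda:0 and weight on cpu, multiplying them directly
    # would raise the "Expected all tensors to be on the same device" error.
    if weight.device != x.device:
        weight = weight.to(x.device)
    return x * weight

x = torch.ones(2, 3)            # stand-in for a tensor on the compute device
w = torch.full((3,), 2.0)       # stand-in for a parameter left on another device
out = scale(x, w)
```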

shiimizu commented 6 months ago

Do you have a workflow I can test?

TheSloppiestOfJoes commented 1 month ago

An ugly hack for this issue until it is properly fixed: add these lines to tiled_vae.py at line 325.

Before:

out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
out = out.view(b, c, *input.size()[2:])

if weight is not None:
    out *= weight.view(1, -1, 1, 1)
if bias is not None:
    out += bias.view(1, -1, 1, 1)
return out

After:

out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps)
out = out.view(b, c, *input.size()[2:])

device = torch.device('privateuseone:0')  # change to your device, e.g. torch.device('cuda:0')
if out.device != device:
    out = out.to(device)
if weight is not None and weight.device != device:
    weight = weight.to(device)
if bias is not None and bias.device != device:
    bias = bias.to(device)

if weight is not None:
    out *= weight.view(1, -1, 1, 1)
if bias is not None:
    out += bias.view(1, -1, 1, 1)
return out

Just replace privateuseone:0 with whatever device you're using ('cuda:0' would probably work in your case, but I can't test that with my AMD card).

This isn't the greatest fix, since tensors will be moved extra times and that causes a slight slowdown, but it should make the script usable in the meantime.
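(A device-agnostic variant of the same workaround avoids hardcoding the device name entirely: move the affine parameters onto whatever device the normalized output already lives on. The helper name below is hypothetical, and this is a sketch of the general pattern rather than a tested patch for this repo:)

```python
import torch

def apply_affine(out: torch.Tensor, weight, bias) -> torch.Tensor:
    # Apply the group-norm affine parameters, moving them to out's device
    # first so we never hardcode 'privateuseone:0' or 'cuda:0'.
    if weight is not None:
        if weight.device != out.device:
            weight = weight.to(out.device)
        out = out * weight.view(1, -1, 1, 1)
    if bias is not None:
        if bias.device != out.device:
            bias = bias.to(out.device)
        out = out + bias.view(1, -1, 1, 1)
    return out
```

The None guards also matter: the hack above calls weight.device unconditionally, which would crash on layers where weight or bias is None (which is exactly why the original code checks for None before using them).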

GrunclePug commented 1 month ago

+1, having this issue. I have two workflows: one uses SDXL the whole time; the other uses Flux for the base image and SDXL for post-processing. The non-Flux version works fine, but the Flux version fails after adding Tiled Diffusion to the SDXL model and trying to decode the result. I've included my workflow, with both versions on the far right side: workflow_grunclepug.json

GrunclePug commented 1 month ago


Thank you for the workaround. I just tried it and will use it, and I've made a note of the modified source code until a fix is released.

GrunclePug commented 1 month ago

Never mind: it got me past that KSampler, but it failed later on when encoding after an upscale and a resize (which both worked), with the error below:

[Tiled VAE]: input_size: torch.Size([1, 3, 4248, 2816]), tile_size: 1024, padding: 32
[Tiled VAE]: split to 5x3 = 15 tiles. Optimal tile size 928x864, original tile size 1024x1024
[Tiled VAE]: Executing Encoder Task Queue:   1%|        | 18/1365 [00:02<02:21,  9.49it/s]./start.sh: line 2:  2109 Killed                  python main.py --listen
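(A note on that log: the 5x3 split follows directly from ceiling division of the 4248x2816 input by the 1024 tile size, so the tiling itself looks correct. The trailing "Killed" line means the operating system terminated the process, which on Linux typically indicates the out-of-memory killer, so this second failure looks like exhausting system RAM rather than the device-mismatch bug above. A quick sketch of the split arithmetic, with a hypothetical helper name:)

```python
import math

def tile_grid(height: int, width: int, tile_size: int) -> tuple[int, int]:
    # Number of tiles needed to cover each dimension (ceiling division).
    return math.ceil(height / tile_size), math.ceil(width / tile_size)

rows, cols = tile_grid(4248, 2816, 1024)  # matches "split to 5x3 = 15 tiles"
```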