rootonchair opened 1 month ago
Hi @WyattAutomation
...python3.10/site-packages/onediff/infer_compiler/backends/oneflow/transform/builtin_transform.py:221 - convert <class 'list'> failed: Transform failed of <class 'list'>: Transform failed of <class 'diffusers.models.attention_processor.Attention'>: Transform failed of <class 'layer_diffuse.models.attention_processors.AttentionSharingProcessor2_0'>: Unsupported type: <class 'list'>
I think this is caused by this part (https://github.com/rootonchair/diffuser_layerdiffuse/blob/main/layer_diffuse/models/attention_processors.py#L502) of the code, which uses a native Python list that the compiler cannot handle. Maybe you can remove it and retry?
I will try this and let you know if it works. Thank you for the suggestion!
I have OneDiff now working in your basic SD 1.5 example!
I removed lines 106, 252, 502, and 664 in layer_diffuse/models/attention_processors.py. I also had to raise the default recursion limit with sys.setrecursionlimit(1000000) in my top-level test script, and then it worked.
Raising the recursion limit may be a concern for some people; I am not sure. I can look into alternatives, but honestly this gets me working for now so I can keep making progress on other development work. Let me know if you push a proper fix for that; I am just using sys.setrecursionlimit(1000000) for now.
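One alternative to a permanent global change is to raise the limit only around the code that needs it and restore it afterwards. sys.setrecursionlimit affects the whole interpreter, and a very high value can let runaway recursion overflow the C stack and crash the process instead of raising RecursionError. This is a minimal sketch and has not been tested against OneDiff; raised_recursion_limit is a hypothetical helper name.

```python
import sys
from contextlib import contextmanager


@contextmanager
def raised_recursion_limit(limit):
    """Temporarily raise Python's recursion limit, restoring the previous value on exit."""
    previous = sys.getrecursionlimit()
    # Only ever raise the limit; a smaller request leaves the current value in place.
    sys.setrecursionlimit(max(previous, limit))
    try:
        yield
    finally:
        # Restore the original limit even if compilation or inference raises.
        sys.setrecursionlimit(previous)
```

The test script below notes that the first generation is what kicks off compilation. If that holds, the with-block should cover both the oneflow_compile calls and that first warm-up call to pipeline(...). Whether later calls also need the raised limit is not established in this thread.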
Just remove or comment out all four lines in attention_processors.py that contain:
self.original_module = [module]
and also add:
sys.setrecursionlimit(1000000)
to the main script used to run inference with the diffusers pipeline and the loaded LayerDiffuse models.
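If you would rather not edit attention_processors.py, you could delete the same attribute at runtime before calling oneflow_compile. This is a sketch built on two assumptions. First, the AttentionSharingProcessor2_0 instances are reachable through the UNet's attn_processors dict (a diffusers property). Second, as the source edit above implies, nothing reads original_module during inference. drop_attribute is a hypothetical helper.

```python
def drop_attribute(objects, name="original_module"):
    """Delete instance attribute `name` from every object that has it.

    Returns how many objects were changed. Objects that appear more than once
    are only counted the first time, because the attribute is gone afterwards.
    """
    changed = 0
    for obj in objects:
        # vars() only looks at the instance dict, so class-level attributes are untouched.
        if name in vars(obj):
            delattr(obj, name)
            changed += 1
    return changed
```

In the test script this would be called after load_lora_to_unet and before compiling, e.g. drop_attribute(pipeline.unet.attn_processors.values()).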
The performance gain also looks great, per the output:
without onediff/oneflow -- ~5s:
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:05<00:00, 9.42it/s]
with onediff/oneflow -- ~2s:
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:02<00:00, 20.81it/s]
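The ~5 s and ~2 s figures follow directly from the tqdm rates for the 50 denoising steps. Note that this bar covers only the UNet denoising loop, not the VAE decode that runs after it. A quick check of the arithmetic:

```python
def seconds_for_steps(steps, it_per_s):
    """Wall time implied by a tqdm rate (iterations per second) over a number of steps."""
    return steps / it_per_s


baseline = seconds_for_steps(50, 9.42)   # about 5.31 s without OneDiff/OneFlow
compiled = seconds_for_steps(50, 20.81)  # about 2.40 s with OneDiff/OneFlow
speedup = baseline / compiled            # about 2.21x, the same as 20.81 / 9.42
```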
The code I used to test OneDiff/OneFlow is:
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
from diffusers import StableDiffusionPipeline
from layer_diffuse.models import TransparentVAEDecoder
from layer_diffuse.loaders import load_lora_to_unet
# import oneflow_compile
from onediff.infer_compiler import oneflow_compile
# workaround for recursion error
import sys
sys.setrecursionlimit(1000000)
model_path = hf_hub_download(
'LayerDiffusion/layerdiffusion-v1',
'layer_sd15_vae_transparent_decoder.safetensors',
)
vae_transparent_decoder = TransparentVAEDecoder.from_pretrained("digiplay/Juggernaut_final", subfolder="vae", torch_dtype=torch.float16).to("cuda")
vae_transparent_decoder.set_transparent_decoder(load_file(model_path))
pipeline = StableDiffusionPipeline.from_pretrained("digiplay/Juggernaut_final", vae=vae_transparent_decoder, torch_dtype=torch.float16, safety_checker=None).to("cuda")
model_path = hf_hub_download(
'LayerDiffusion/layerdiffusion-v1',
'layer_sd15_transparent_attn.safetensors'
)
load_lora_to_unet(pipeline.unet, model_path, frames=1)
pipeline.to("cuda")
# oneflow_compile the unet and vae decoder
pipeline.unet = oneflow_compile(pipeline.unet)
pipeline.vae.decoder = oneflow_compile(pipeline.vae.decoder)
prompt="Enraged kodiak bear with pterodactyl wings throwing a molotov cocktail riding a skateboard, on fire, high quality"
# do an initial generation to kick off the model compilation
image = pipeline(
prompt=prompt,
width=512, height=512,
num_images_per_prompt=1, return_dict=False)[0]
image[0].save("onediff_test_with_compile.png")
# do a second generation to see the actual speed of the compiled inference
image2 = pipeline(
prompt=prompt,
width=512, height=512,
num_images_per_prompt=1, return_dict=False)[0]
image2[0].save("onediff_test_with_compile_fullspeed.png")
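The script compares the warm-up (compile) generation and the second generation by reading the progress bars. To put numbers on the two calls, a small stdlib helper could wrap each one. timed is a hypothetical helper. Passing torch.cuda.synchronize as sync is suggested because CUDA kernels launch asynchronously, so without a sync the clock can stop before the GPU work has finished.

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds).

    `sync` is an optional zero-argument callable (e.g. torch.cuda.synchronize)
    that runs before the clock starts and before it stops, so queued GPU work
    is included in the measurement.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start
```

For example, timed(pipeline, prompt=prompt, width=512, height=512, num_images_per_prompt=1, return_dict=False, sync=torch.cuda.synchronize) could be called once for the warm-up generation and once for the second one.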
...Thanks a ton!
So, after some additional evaluation, I have one remaining blocker to resolve before this runs in real time. Hopefully it's something that can be addressed easily.
This is for the real-time SD 1.5 + ControlNet video game world demo I shared in a comment here: https://github.com/lllyasviel/LayerDiffuse/issues/32#issuecomment-2351771297
Whenever I run my pipeline, everything compiles, runs, and starts generating images.
However, whenever TransparentVAEDecoder is used in my pipeline, every generated frame is interrupted by a progress bar in the terminal. It runs 8 steps of something and takes about 2 whole seconds to complete (I assume this is the transparent VAE decoder doing its work).
Even when I set "disable_progress_bar" on my pipeline object before OneDiff compiles the vae.decoder/unet/contronet, that progress bar still shows up in the terminal. It runs its 8 steps and holds up each frame for 2 seconds until it finishes, and this repeats for every single frame.
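The 2-second stall matters because of how a per-frame cost caps the frame rate. A stage that takes ~2000 ms per frame limits the loop to 0.5 FPS on its own, however fast the UNet is, while a real-time target like 12 FPS leaves only about 83 ms for everything. The two helper names are illustrative.

```python
def frame_budget_ms(target_fps):
    """Time available per frame, in milliseconds, at a target frame rate."""
    return 1000.0 / target_fps


def fps_ceiling(per_frame_ms):
    """Highest frame rate possible if every frame spends per_frame_ms in one stage."""
    return 1000.0 / per_frame_ms
```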
Is it possible to make the Transparent VAE Decoder not require 8 whole steps on every single image? 1 would be ideal.
The most I ever use is 3 steps, and that is for UNet denoising. In that real-time app I use a DMD-distilled SD 1.5 UNet made for 1 to 4 steps, DreamShaper7_LCM as the base model in the pipeline, TinyVAE (taesd), and the LCM scheduler, all of which help speed up my pipeline.
Without LayerDiffuse, the app generates a frame in 1 to 4 steps, at about 9 to 13 FPS when using 2 or more ControlNets. Roughly 11-12 FPS is the lowest acceptable frame rate. I wanted to run multiple separate diffusers pipelines in separate threads, as a basic producer/consumer architecture for compositing images, to get performance closer to 20 FPS. I also want to use LayerDiffuse to really improve the quality of the existing app.
Thanks again for your help here. I feel like we are close to having it working in real time, with just one last item to address.
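The producer/consumer idea can be sketched with the standard library alone. Each render function is a hypothetical stand-in for one diffusers pipeline call and runs in its own thread, pushing layers onto a bounded queue. The consumer then gathers one layer per producer for each frame so the layers can be composited in order. This sketch only shows the threading structure. It does not establish that two pipelines sharing one GPU overlap enough to raise FPS, and it has no error handling: if a producer raises, the consumer is left waiting.

```python
import queue
import threading


def run_pipeline_workers(render_fns, num_frames):
    """Run each render_fn in its own producer thread; return per-frame lists of layers.

    Producers put (frame_idx, layer_idx, layer) on a shared queue. The consumer
    (the calling thread) slots each layer into its frame, in producer order.
    """
    # Bounded queue so producers cannot run arbitrarily far ahead of the consumer.
    q = queue.Queue(maxsize=2 * len(render_fns))

    def producer(layer_idx, render):
        for frame_idx in range(num_frames):
            q.put((frame_idx, layer_idx, render(frame_idx)))

    threads = [
        threading.Thread(target=producer, args=(i, fn), daemon=True)
        for i, fn in enumerate(render_fns)
    ]
    for t in threads:
        t.start()

    frames = {}
    for _ in range(num_frames * len(render_fns)):
        frame_idx, layer_idx, layer = q.get()
        frames.setdefault(frame_idx, [None] * len(render_fns))[layer_idx] = layer

    for t in threads:
        t.join()
    return [frames[i] for i in range(num_frames)]
```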
Again, the summary is:
How to remove or reduce the 8 steps that run when using TransparentVAEDecoder as the VAE
It is due to this function: https://github.com/rootonchair/diffuser_layerdiffuse/blob/main/layer_diffuse/models/modules.py#L239
You could try using estimate_single_pass only.
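The 8 steps come from test-time augmentation. In the layer_diffuse code, estimate_augmented runs the transparent decoder on 4 rotations × {unflipped, horizontally flipped} copies of the input, undoes each transform on the output, and takes the per-pixel median. estimate_single_pass is one call. Below is a pure-Python sketch of that structure on 2-D lists. It simplifies two things: `estimate` is a hypothetical stand-in for the decoder and takes a single input rather than the real (pixel, latent) pair. It is meant only to show why there are 8 passes.

```python
from statistics import median


def rot90(img, k):
    """Rotate a 2-D list by 90 degrees k times (same direction as torch.rot90 over dims (0, 1))."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def hflip(img):
    """Flip a 2-D list along its width axis."""
    return [row[::-1] for row in img]


def estimate_single_pass(estimate, img):
    """One pass, no augmentation."""
    return estimate(img)


def estimate_augmented(estimate, img):
    """Median of 8 passes: rotations 0-3, without and then with a horizontal flip."""
    results = []
    for flip in (False, True):
        for k in range(4):
            x = rot90(hflip(img) if flip else img, k)
            y = rot90(estimate(x), -k)  # undo the rotation on the output
            results.append(hflip(y) if flip else y)  # then undo the flip
    h, w = len(img), len(img[0])
    return [[median(r[i][j] for r in results) for j in range(w)] for i in range(h)]
```

With a well-behaved decoder the median mainly smooths out small orientation-dependent errors, so a single pass trades some of that robustness for roughly an 8× cheaper decode.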
Brilliant, I will try this out and let you know how it goes.
Progress! I got it working, though only at ~5 FPS at the moment. This appears to be the correct flow for the pipeline, so I believe you have solved the OneDiff/OneFlow compatibility problem (this issue).
https://github.com/user-attachments/assets/a908921a-1323-495c-90fb-d5a54a5eab8f
I replaced the use of estimate_augmented in the decode method of TransparentVAEDecoder here: https://github.com/rootonchair/diffuser_layerdiffuse/blob/00b5ffaf236ae67992fba57d6984a1b7410d0340/layer_diffuse/models/modules.py#L281 with:
y = self.estimate_single_pass(pixel[i:i+1], z[i:i+1])
and that solved it!
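Editing modules.py works. If you would rather leave the installed package untouched, you could shadow the method on your decoder instance instead. This is a minimal sketch that assumes decode() calls self.estimate_augmented(pixel, latent) with the same arguments estimate_single_pass takes, as in the linked revision. use_single_pass is a hypothetical helper name. The approach relies on ordinary Python attribute lookup: an instance attribute takes precedence over a plain method defined on the class. torch.nn.Module keeps non-tensor, non-module attributes such as a bound method in the instance dict, so the same lookup rule applies there.

```python
def use_single_pass(decoder):
    """Route decoder.estimate_augmented to decoder.estimate_single_pass on this instance only.

    The instance attribute shadows the class method, so a decode() that calls
    self.estimate_augmented(...) runs one pass instead of eight.
    """
    decoder.estimate_augmented = decoder.estimate_single_pass
    return decoder
```

Usage would be use_single_pass(vae_transparent_decoder) after set_transparent_decoder and before building the pipeline.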
The only thing I think I need now is to be able to use TinyVAE as the input VAE to TransparentVAEDecoder, if that is feasible. I believe TinyVAE ('madebyollin/taesd') was contributing a lot to performance, so losing it may have caused a fairly large dip in speed (the frame rate at 2 steps roughly halved; it was above 10 FPS previously).
Whenever I try passing in TinyVAE as the VAE for TransparentVAEDecoder like this:
vae_transparent_decoder = TransparentVAEDecoder.from_pretrained('madebyollin/taesd', torch_device='cuda', torch_dtype=torch.float16)
it fails with:
"ValueError: Cannot load <class 'Layerdiffuse.layer_diffuse.models.modules.TransparentVAEDecoder'> from madebyollin/taesd because the following keys are missing: decoder.midblock.resnets.0.conv1.weight, ..." (the list keeps going)
I assume that because TransparentVAEDecoder inherits from AutoencoderKL, it is unable to load 'madebyollin/taesd', which requires the AutoencoderTiny class to instantiate it:
vae_tmp = AutoencoderTiny.from_pretrained('madebyollin/taesd', torch_device='cuda', torch_dtype=torch.float16)
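One way to confirm which class a checkpoint expects is to read the "_class_name" field that diffusers writes into a model's config.json. diffusers_class_name is a hypothetical helper. You could point it at a downloaded config, e.g. hf_hub_download('madebyollin/taesd', 'config.json'), and compare the result with the vae subfolder of an SD 1.5 checkpoint (hf_hub_download(repo, 'config.json', subfolder='vae')).

```python
import json


def diffusers_class_name(config_path):
    """Return the "_class_name" recorded in a diffusers model config.json, or None if absent."""
    with open(config_path, encoding="utf-8") as f:
        return json.load(f).get("_class_name")
```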
This is a separate issue, though. Let me know if you're interested in helping, and I can open a new issue so you can close this one.
Maybe a "TransparentTinyVAEDecoder" is needed so TinyVAE can be used? Or maybe there is something easier that I am missing. Either way, I think integrating it might be a useful next step, since TinyVAE is commonly used for optimization.
Thanks again, and I'll keep an eye on my email in case you're interested in helping get TinyVAE working with your TransparentVAEDecoder class (or at least determining whether it is feasible).
Hi @WyattAutomation, this is totally doable: copy the original TransparentVAEDecoder, inherit from AutoencoderTiny instead, and replace the constructor parameters with the ones here: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_tiny.py#L98-L115
from typing import Tuple, Union

import cv2
import torch
from tqdm import tqdm
from diffusers import AutoencoderTiny
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput
# UNet1024 and checkerboard live in layer_diffuse/models/modules.py alongside the original
# TransparentVAEDecoder; drop this import if you paste the class into that file.
from layer_diffuse.models.modules import UNet1024, checkerboard


class TransparentTinyVAEDecoder(AutoencoderTiny):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "relu",
        upsample_fn: str = "nearest",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        self.mod_number = None
        # Mirrors AutoencoderTiny.__init__; older diffusers versions may not accept every parameter.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            encoder_block_out_channels=encoder_block_out_channels,
            decoder_block_out_channels=decoder_block_out_channels,
            act_fn=act_fn,
            upsample_fn=upsample_fn,
            latent_channels=latent_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            num_encoder_blocks=num_encoder_blocks,
            num_decoder_blocks=num_decoder_blocks,
            latent_magnitude=latent_magnitude,
            latent_shift=latent_shift,
            force_upcast=force_upcast,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
        )

    def set_transparent_decoder(self, sd, mod_number=1):
        model = UNet1024(in_channels=3, out_channels=4)
        model.load_state_dict(sd, strict=True)
        model.to(device=self.device, dtype=self.dtype)
        model.eval()
        self.transparent_decoder = model
        self.mod_number = mod_number

    def estimate_single_pass(self, pixel, latent):
        y = self.transparent_decoder(pixel, latent)
        return y

    def estimate_augmented(self, pixel, latent):
        # 8 passes: rotations 0-3, without and with a horizontal flip; the median is returned.
        args = [
            [False, 0], [False, 1], [False, 2], [False, 3], [True, 0], [True, 1], [True, 2], [True, 3],
        ]
        result = []
        for flip, rok in tqdm(args):
            feed_pixel = pixel.clone()
            feed_latent = latent.clone()
            if flip:
                feed_pixel = torch.flip(feed_pixel, dims=(3,))
                feed_latent = torch.flip(feed_latent, dims=(3,))
            feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3))
            feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3))
            eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1)
            eps = torch.rot90(eps, k=-rok, dims=(2, 3))
            if flip:
                eps = torch.flip(eps, dims=(3,))
            result += [eps]
        result = torch.stack(result, dim=0)
        median = torch.median(result, dim=0).values
        return median

    def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
        # AutoencoderTiny decodes to [-1, 1]; rescale to [0, 1] for the transparent decoder.
        pixel = super().decode(z, return_dict=False, generator=generator)[0]
        pixel = pixel / 2 + 0.5
        result_pixel = []
        for i in range(int(z.shape[0])):
            if self.mod_number is None or (self.mod_number != 1 and i % self.mod_number != 0):
                # Skipped item: append a fully opaque alpha channel.
                img = torch.cat((pixel[i:i+1], torch.ones_like(pixel[i:i+1, :1, :, :])), dim=1)
                result_pixel.append(img)
                continue
            # Swap in self.estimate_single_pass(...) here for 1 pass instead of 8.
            y = self.estimate_augmented(pixel[i:i+1], z[i:i+1])
            y = y.clip(0, 1).movedim(1, -1)
            alpha = y[..., :1]
            fg = y[..., 1:]
            B, H, W, C = fg.shape
            # cb (a checkerboard preview background) is computed but not used in the output below.
            cb = checkerboard(shape=(H // 64, W // 64))
            cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST)
            cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
            cb = torch.from_numpy(cb).to(fg)
            png = torch.cat([fg, alpha], dim=3)
            png = png.permute(0, 3, 1, 2)
            result_pixel.append(png)
        result_pixel = torch.cat(result_pixel, dim=0)
        result_pixel = (result_pixel - 0.5) * 2
        if not return_dict:
            return (result_pixel, )
        return DecoderOutput(sample=result_pixel)
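The mod_number check in decode() determines which items in a batch get the transparent pass. Before set_transparent_decoder is called, mod_number is None and every item only gets an alpha channel of ones. set_transparent_decoder defaults it to 1, which decodes every item. Any other value decodes items 0, mod_number, 2*mod_number, and so on. A plain-Python mirror of that condition (transparent_frame_indices is an illustrative name):

```python
def transparent_frame_indices(batch_size, mod_number):
    """Indices of batch items that decode() sends through the transparent decoder.

    Mirrors decode()'s skip condition:
    skip if mod_number is None or (mod_number != 1 and i % mod_number != 0).
    """
    if mod_number is None:
        return []
    return [i for i in range(batch_size) if mod_number == 1 or i % mod_number == 0]
```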
Excellent! I will give this a try and let you know how it goes
Just an update!
Your class works with only 2 minimal adjustments. I think our diffusers versions may differ, but I had to remove a couple of params from the init, and after that it works great!
That said, I was still running at about 8FPS, at 2 steps. This is fine as I was expecting to have to get 1-step working again.
Without layerdiffuse, I can do a really ugly single step ControlNet at like 12FPS.
So yesterday I took it upon myself to revisit optimization. I integrated an obscure pretrained 48k-step DMD model from Hugging Face (no downloads, no model card) into InstaFlow, and with one ControlNet it went from 12 FPS to 30 FPS! Yes, that is right: THIRTY frames per second.
I will follow up with proof of this, but it was about 4:45 AM when I finally got it working, so I haven't recorded demos yet.
Anyway I plan to add your now working OneDiff+TinyVAE compatible LayerDiffusion to this crazy-fast optimized pipeline. From how it looks, it should achieve north of 20FPS when I try it out -- I'll keep you posted!
Once I have this working I can do a PR for the code.
Also, as a treat, I can include my pipeline code as an example of the model combos needed to get the speed as high as I did. Even my old multi-ControlNet setup went from 10 FPS to 22 FPS because of that obscure 48k DMD model I found. Once I have LayerDiffuse working, I'll post a minimal example here.
I had to figure out by trial and error that the model was just a UNet (I think). It's from the same person who released the unofficial DMD I used, but instead of "1k" at the end of the name it has "48k".
More "k's" gotta be a good thing right?..
Apparently it's a REALLY good thing. I am getting a solid 30 FPS when using it as the UNet for an InstaFlow class. With 2 ControlNets and 3 LoRAs, it still gets 21 FPS. Check it out: https://github.com/gnobitab/InstaFlow/issues/8#issuecomment-2363218338
I don't think anyone realizes this exists yet. I am fairly sure that if you just run txt2img with this combo and without ControlNet, it will run at 45 FPS at least, probably higher.
I am going to make several backups of this 48k model. It's all public, so it'll get out there anyway, but here is the UNet model I stumbled on. Mine are the only downloads of it so far, but I wanted to share this with you for helping me. I will share the pipeline code I used with InstaFlow here in a minute, but when you get a chance, clone this repo and save a backup of it somewhere. This is hands down the fastest diffusion UNet I have ever seen. I don't know if I am overreacting or if this is already known, but have you heard of anyone getting 40+ FPS yet?
I hope they didn't release it accidentally, but it's out there. I don't think they did; I just think they didn't announce it. Either way, here is this golden goose of a UNet model. Load it separately, pass it in as the "unet" of a pipeline, and then dial it down to 1 step. I about choked on my coffee when I saw 32 FPS: https://huggingface.co/aaronb/dreamshaper-8-dmd-48k/tree/main
https://github.com/lllyasviel/LayerDiffuse/issues/32#issuecomment-2352046393