omerbt / TokenFlow

Official Pytorch Implementation for "TokenFlow: Consistent Diffusion Features for Consistent Video Editing" presenting "TokenFlow" (ICLR 2024)
https://diffusion-tokenflow.github.io
MIT License

SD XL Integration #39

Open irvansian opened 9 months ago

irvansian commented 9 months ago

Is it possible to integrate the latest SD XL into the Stable Diffusion options?

Zeldalina commented 5 months ago

Same question here. In my tests, directly swapping in the SD XL model produces the following error:

```
Traceback (most recent call last):
  File "/root/TokenFlow/preprocess-Longer-n-HR_video.py", line 372, in <module>
    prep(opt)
  File "/root/TokenFlow/preprocess-Longer-n-HR_video.py", line 334, in prep
    recon_frames = model.extract_latents(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/TokenFlow/preprocess-Longer-n-HR_video.py", line 290, in extract_latents
    inverted_x = self.ddim_inversion(cond,
  File "/root/miniconda3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/TokenFlow/preprocess-Longer-n-HR_video.py", line 238, in ddim_inversion
    eps = self.unet(model_input, t, encoder_hidden_states=cond_batch).sample if self.sd_version != 'ControlNet' \
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py", line 1216, in forward
    sample, res_samples = downsample_block(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 1279, in forward
    hidden_states = attn(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/diffusers/models/transformers/transformer_2d.py", line 397, in forward
    hidden_states = block(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/diffusers/models/attention.py", line 366, in forward
    attn_output = self.attn2(
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 522, in forward
    return self.processor(
  File "/root/miniconda3/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1266, in __call__
    key = attn.to_k(encoder_hidden_states)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (20480x4096 and 1024x320)
```
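For reference, the failure is at the cross-attention `to_k` projection: the feature dimension of `encoder_hidden_states` does not match what the UNet's cross-attention layers expect, so SD XL cannot simply be dropped into the existing model path. In diffusers, SDXL conditions on the concatenated hidden states of two text encoders (2048 features instead of 768/1024) and additionally requires `added_cond_kwargs` (pooled `text_embeds` plus micro-conditioning `time_ids`), none of which the plain `self.unet(model_input, t, encoder_hidden_states=cond_batch)` call in the preprocessing script provides. Below is a minimal sketch, not a TokenFlow patch, of what an SDXL UNet forward call needs; the model id, prompt, resolution, and tensor shapes are placeholder assumptions.

```python
# Minimal sketch of an SDXL-style UNet call (assumes the diffusers
# StableDiffusionXLPipeline API; model id, prompt, resolution and latent
# shapes below are placeholders, not TokenFlow code).
import torch
from diffusers import StableDiffusionXLPipeline

device = "cuda"
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to(device)

# SDXL conditions on the concatenated hidden states of BOTH text encoders
# (768 + 1280 = 2048 features), so embeddings produced by a single
# SD 1.5/2.1 text encoder (768/1024 features) cannot be reused.
prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
    prompt="a photo of a cat", device=device, do_classifier_free_guidance=False
)

# SDXL micro-conditioning: [orig_h, orig_w, crop_top, crop_left, target_h, target_w].
add_time_ids = torch.tensor(
    [[1024, 1024, 0, 0, 1024, 1024]], dtype=prompt_embeds.dtype, device=device
)

latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device=device)  # 1024x1024 image -> 128x128 latent
t = torch.tensor(999, device=device)

# The plain call used in the preprocessing script,
# self.unet(x, t, encoder_hidden_states=cond), is not enough for SDXL;
# the extra conditioning has to be passed explicitly:
noise_pred = pipe.unet(
    latents,
    t,
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
).sample
```

Even with the forward call adapted, the DDIM inversion in `extract_latents`/`ddim_inversion` and the extended-attention feature injection would presumably also need SDXL-aware changes, so this is only a starting point.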