IanYeung / MGLD-VSR

Code for ECCV 2024 Paper "Motion-Guided Latent Diffusion for Temporally Consistent Real-world Video Super-resolution"

RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [25, 4096, 64]->[25, 4096, 1, 64] [5, 77, 64]->[5, 1, 77, 64] #13

Open ZHAOZHIHAO opened 5 months ago

ZHAOZHIHAO commented 5 months ago


While running

python scripts/vsr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py \
  --config configs/mgldvsr/mgldvsr_512_realbasicvsr_deg.yaml \
  --ckpt ./weights/mgldvsr_unet.ckpt \
  --vqgan_ckpt ./weights/video_vae_cfw.ckpt \
  --seqs-path ./input_seqs \
  --outdir ./building_output \
  --ddpm_steps 50 \
  --dec_w 1.0 \
  --colorfix_type adain \
  --select_idx 0 \
  --n_gpus 1

I got the following error. Do you know what's wrong here? I have tried different image sequences, and the error happens every time. By the way, I followed these instructions to run your code: https://gist.github.com/meisa233/1549bb95c5c130e3a93fcab17c83e931

Global seed set to 42
>>>>>>>>>>color correction>>>>>>>>>>>
Use adain color correction
Loading model from ././weights/mgldvsr_unet.ckpt
Global Step: 42000
LatentDiffusionVSRTextWT: Running in eps-prediction mode
Setting up MemoryEfficientSelfAttention. Query dim is 1280, using 20 heads.
DiffusionWrapper has 935.32 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 64, 64) = 16384 dimensions.
making attention of type 'vanilla' with 512 in_channels
Encoder Restored from /flash/zhihaoz/vsr/MGLD-VSR/weights/v2-1_512-ema-pruned.ckpt with 0 missing and 1242 unexpected keys
Restored from /flash/zhihaoz/vsr/MGLD-VSR/weights/v2-1_512-ema-pruned.ckpt with 588 missing and 38 unexpected keys
Segment shape:  torch.Size([5, 3, 2048, 2048])
Segment shape:  torch.Size([5, 3, 2048, 2048])
Sampling:   0%|          | 0/2 [00:00<?, ?it/s]
seq:  road_512 seg:  0 size:  torch.Size([5, 3, 2048, 2048])
/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/functional.py:478: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484809662/work/aten/src/ATen/native/TensorShape.cpp:2894.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Global seed set to 42
Sampling t:   0%|          | 0/50 [00:00<?, ?it/s]
Sampling:   0%|

Traceback (most recent call last):
  File "/flash/zhihaoz/vsr/MGLD-VSR/scripts/vsr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py", line 565, in <module>
  File "/flash/zhihaoz/vsr/MGLD-VSR/scripts/vsr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py", line 459, in main
    samples, _ = model.sample_canvas(cond=semantic_c,
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/models/diffusion/ddpm.py", line 4735, in sample_canvas
    return self.p_sample_loop_canvas(cond,
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/models/diffusion/ddpm.py", line 4673, in p_sample_loop_canvas
    img = self.p_sample_canvas(img, cond, struct_cond, ts, guidance_scale=guidance_scale,
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/models/diffusion/ddpm.py", line 4390, in p_sample_canvas
    outputs = self.p_mean_variance_canvas(x=x, c=c, struct_cond=struct_cond, t=t, clip_denoised=clip_denoised,
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/models/diffusion/ddpm.py", line 4256, in p_mean_variance_canvas
    model_out = self.apply_model(input_list, t_in[:input_list.size(0)], c[:input_list.size(0)], struct_cond_input, return_ids=return_codebook_ids)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/models/diffusion/ddpm.py", line 4080, in apply_model
    x_recon = self.model(x_noisy, t, **cond)
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/models/diffusion/ddpm.py", line 4927, in forward
    out = self.diffusion_model(x, t, context=cc, struct_cond=struct_cond)
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/diffusionmodules/openaimodel.py", line 2303, in forward
    h = module(h, emb, context, struct_cond)
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/diffusionmodules/openaimodel.py", line 145, in forward
    x = layer(x, context)
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/attention.py", line 540, in forward
    x = block(x, context=context[i])
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/attention.py", line 429, in forward
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/diffusionmodules/util.py", line 116, in checkpoint
    return func(*inputs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/attention.py", line 433, in _forward
    x = self.attn2(self.norm2(x), context=context) + x
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/flash/zhihaoz/vsr/MGLD-VSR/ldm/modules/attention.py", line 246, in forward
    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
  File "/flash/zhihaoz/conda/envs/mgldvsr/lib/python3.9/site-packages/torch/functional.py", line 360, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [25, 4096, 64]->[25, 4096, 1, 64] [5, 77, 64]->[5, 1, 77, 64]
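
For readers hitting the same message: einsum requires every dimension that shares a label to have equal sizes (or size 1) across operands. The sketch below re-creates that check in plain Python for 'b i d, b j d -> b i j'. Each operand is laid out over the labels b, i, j, d with size 1 for absent labels, which reproduces the "remapped" shapes printed in the error, and the batch dimension b then fails to broadcast (25 vs 5). The helpers remap and broadcast_shapes are hypothetical illustrations, not MGLD-VSR or PyTorch code. Reading 25 as 5 frames × 5 heads and 5 as a text context with batch 1 × 5 heads is an assumption consistent with the numbers, not something confirmed in this thread.

```python
# Hypothetical illustration of the einsum shape check behind the error above.
# Not PyTorch's implementation; it only mirrors the broadcasting rule.

def remap(shape, labels, all_labels="bijd"):
    """Lay an operand's dims out over all einsum labels; absent labels get size 1."""
    sizes = dict(zip(labels, shape))
    return tuple(sizes.get(label, 1) for label in all_labels)


def broadcast_shapes(*shapes):
    """Broadcast equal-length shapes: per dim, sizes must match or be 1."""
    result = []
    for dims in zip(*shapes):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"cannot broadcast sizes {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


# Shapes from the traceback for einsum('b i d, b j d -> b i j', q, k).
q = remap((25, 4096, 64), "bid")  # (25, 4096, 1, 64)
k = remap((5, 77, 64), "bjd")     # (5, 1, 77, 64)

try:
    broadcast_shapes(q, k)
except ValueError as err:
    print("mismatch:", err)  # mismatch: cannot broadcast sizes (25, 5)

# Assumption: 25 = 5 frames * 5 heads and 5 = 1 text context * 5 heads.
# If the context batch matched the 5 frames, k would have 25 rows and the check passes.
k_matched = remap((25, 77, 64), "bjd")
print(broadcast_shapes(q, k_matched))  # (25, 4096, 77, 64)
```

This only explains which dimension disagrees; it does not say which package change causes the text context to arrive with the wrong batch size.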

VingtDylan commented 4 months ago

Same issue here. Any advice?

IanYeung commented 4 months ago

Hi. This might be due to an environment problem; I have also encountered it in some new environments with newer packages. You may refer to issue #9.
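
Since the maintainer points to package versions, a useful first step when comparing against a working setup (or against whatever is discussed in issue #9) is to record exactly what is installed. Below is a minimal sketch using only the standard library's importlib.metadata. The package list is an assumption about which packages are relevant, not a list taken from the MGLD-VSR repository.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumed list of packages worth comparing; adjust to your environment.
PACKAGES = ["torch", "torchvision", "xformers", "transformers", "numpy"]


def package_versions(names):
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    print(f"{'python':15s} {platform.python_version()}")
    for name, ver in package_versions(PACKAGES).items():
        print(f"{name:15s} {ver}")
```

Posting this output alongside a report makes it easier to compare against an environment where the script runs.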