archinetai / audio-diffusion-pytorch

Audio generation using diffusion models, in PyTorch.
MIT License
1.92k stars, 167 forks

RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2 #83

Open erikqu opened 5 months ago

erikqu commented 5 months ago

All of my batches seem to hit the same underlying issue: the tensor sizes are off by one. I've seen other issues opened about this, but none with a proper explanation. Could I get some help with this?

Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2

I verified that the text and wav batches both have the batch size (16). All wavs are padded to 84480 samples in this case.

RuntimeError: The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                                
Failed during forward The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                        
The size of tensor a (91) must match the size of tensor b (90) at non-singleton dimension 2                                                                                                                                                                                              
torch.Size([16, 84480]) 16                                                                                                                  
Traceback (most recent call last):                                                                                                          
  File "/mnt/nvme/programs/qTTS/train_ttv_v1.py", line 184, in train_and_evaluate                                                           
    loss_gen_all = net_g(                                                                                                                   
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1523, in forward                                                                                                                                                                                 
    else self._run_ddp_forward(*inputs, **kwargs)                                                                                           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward                                                                                                                                                                        
    return self.module(*inputs, **kwargs)  # type: ignore[index]                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/models.py", line 40, in forward                                                                                                                                                                                  
    return self.diffusion(*args, **kwargs)                                                                                                  
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/audio_diffusion_pytorch/diffusion.py", line 93, in forward                                                                                                                                                                               
    v_pred = self.net(x_noisy, sigmas, **kwargs)                                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                      
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 594, in forward                                                                                                                                                                                                  
    return net(x, features=features, **kwargs)                                                                                              
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                                                                                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                      
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 621, in forward                                                     
    return net(x, embedding=text_embedding, **kwargs)                                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 63, in forward                                                                                                                                                                                                   
    return forward_fn(*args, **kwargs)                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 552, in forward                                                                                                                                                                                                  
    return net(x, embedding=embedding, **kwargs)                                                                                            
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 431, in forward                                                                                                                                                                                                    
    return self.net(x, features, embedding, channels)  # type: ignore                                                                       
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/apex.py", line 382, in forward                                                                                                                                                                                                    
    x = self.block(x, features, embedding, channels)                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                                                                                                                                                                    
    return forward_call(*args, **kwargs)                                                                                                    
  File "/usr/local/lib/python3.10/dist-packages/a_unet/blocks.py", line 77, in forward                                                      
    x = block(x, *args)                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl                                                                                                                                                                            
    return self._call_impl(*args, **kwargs)                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl                                              

I followed the example from the README:

  from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

  net_g = DiffusionModel(
      net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
      in_channels=1, # U-Net: number of input/output (audio) channels
      channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
      factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
      items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
      attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
      attention_heads=8, # U-Net: number of attention heads per attention item
      attention_features=64, # U-Net: number of attention features per attention item
      diffusion_t=VDiffusion, # The diffusion method used
      sampler_t=VSampler, # The diffusion sampler used
      use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
      use_embedding_cfg=True, # U-Net: enables classifier free guidance
      embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
      embedding_features=768, # U-Net: text embedding features (default for T5-base)
      cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
  )
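
One thing that may be worth checking, offered as an assumption rather than a confirmed diagnosis: off-by-one shape errors in U-Nets often appear when the input length is not a multiple of the product of the downsampling factors, because a downsample/upsample round trip then does not return the original length. The sketch below uses only the `factors` list and the padded length (84480) quoted above. The helper `padded_length` is a hypothetical name introduced for illustration and is not part of `audio_diffusion_pytorch` or `a_unet`.

```python
from math import prod

# Values copied from the post above.
factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]
num_samples = 84480


def padded_length(length: int, factors: list[int]) -> int:
    """Return the smallest length >= `length` that is a multiple of prod(factors).

    Hypothetical helper for illustration; not a library function.
    """
    total = prod(factors)
    # Ceiling division, then scale back up to a whole multiple.
    return -(-length // total) * total


total_factor = prod(factors)                  # overall downsampling factor
remainder = num_samples % total_factor        # non-zero means not evenly divisible
target = padded_length(num_samples, factors)  # candidate length to pad to

print(total_factor, remainder, target)  # prints: 2048 512 86016
```

With these numbers, 84480 is not a multiple of 2048, so padding (or cropping) the wavs to a multiple such as 86016 before the forward pass would be one thing to try. Whether that resolves the specific 91 vs 90 mismatch in this traceback has not been verified here.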