archinetai / audio-diffusion-pytorch

Audio generation using diffusion models, in PyTorch.
MIT License

AssertionError: ClassiferFreeGuidancePlugin requires embedding #71

Open gg4u opened 1 year ago

gg4u commented 1 year ago

Hi, I tested the example you gave for conditioning on text, but got this error:

# Train model with audio waveforms
audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
loss = model(
    audio_wave,
    text=['The audio description'], # Text conditioning, one element per batch
    embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
)
loss.backward()

# Turn noise into new audio sample with diffusion
noise = torch.randn(1, 2, 2**18)
sample = model.sample(
    noise,
    text=['The audio description'],
    embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
    num_steps=2 # Higher for better quality, suggested num_steps: 10-100
)

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[49], line 3
      1 # Train model with audio waveforms
      2 audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
----> 3 loss = model(
      4     audio_wave,
      5     text=['The audio description'], # Text conditioning, one element per batch
      6     embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
      7 )
      8 loss.backward()
     10 # Turn noise into new audio sample with diffusion

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/models.py:40, in DiffusionModel.forward(self, *args, **kwargs)
     39 def forward(self, *args, **kwargs) -> Tensor:
---> 40     return self.diffusion(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/diffusion.py:93, in VDiffusion.forward(self, x, **kwargs)
     91 v_target = alphas * noise - betas * x
     92 # Predict velocity and return loss
---> 93 v_pred = self.net(x_noisy, sigmas, **kwargs)
     94 return F.mse_loss(v_pred, v_target)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
     62 def forward(self, *args, **kwargs):
---> 63     return forward_fn(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:594, in TimeConditioningPlugin.<locals>.Net.<locals>.forward(x, time, features, **kwargs)
    592 # Merge time features with features if provided
    593 features = features + time_features if exists(features) else time_features
--> 594 return net(x, features=features, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
     62 def forward(self, *args, **kwargs):
---> 63     return forward_fn(*args, **kwargs)

File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:534, in ClassifierFreeGuidancePlugin.<locals>.Net.<locals>.forward(x, embedding, embedding_scale, embedding_mask_proba, **kwargs)
    526 def forward(
    527     x: Tensor,
    528     embedding: Optional[Tensor] = None,
   (...)
    531     **kwargs,
    532 ):
    533     msg = "ClassiferFreeGuidancePlugin requires embedding"
--> 534     assert exists(embedding), msg
    535     b, device = embedding.shape[0], embedding.device
    536     embedding_mask = fixed_embedding(embedding)

AssertionError: ClassiferFreeGuidancePlugin requires embedding

Is this about dependencies? What dependencies am I supposed to install?

P.S. Can you please show two simple Colab examples?

I am trying to understand how to condition on text to validate a research idea in bioacoustics, but I don't yet have strong enough foundations to fully understand your code, so a tutorial would be really helpful.