Hi, I tested the example you gave for conditioning on text, but got an error:
# Train model with audio waveforms
audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
loss = model(
audio_wave,
text=['The audio description'], # Text conditioning, one element per batch
embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
)
loss.backward()
# Turn noise into new audio sample with diffusion
noise = torch.randn(1, 2, 2**18)
sample = model.sample(
noise,
text=['The audio description'],
embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
num_steps=2 # Higher for better quality, suggested num_steps: 10-100
)
Error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[49], line 3
1 # Train model with audio waveforms
2 audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
----> 3 loss = model(
4 audio_wave,
5 text=['The audio description'], # Text conditioning, one element per batch
6 embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
7 )
8 loss.backward()
10 # Turn noise into new audio sample with diffusion
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/models.py:40, in DiffusionModel.forward(self, *args, **kwargs)
39 def forward(self, *args, **kwargs) -> Tensor:
---> 40 return self.diffusion(*args, **kwargs)
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/audio_diffusion_pytorch/diffusion.py:93, in VDiffusion.forward(self, x, **kwargs)
91 v_target = alphas * noise - betas * x
92 # Predict velocity and return loss
---> 93 v_pred = self.net(x_noisy, sigmas, **kwargs)
94 return F.mse_loss(v_pred, v_target)
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
62 def forward(self, *args, **kwargs):
---> 63 return forward_fn(*args, **kwargs)
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:594, in TimeConditioningPlugin.<locals>.Net.<locals>.forward(x, time, features, **kwargs)
592 # Merge time features with features if provided
593 features = features + time_features if exists(features) else time_features
--> 594 return net(x, features=features, **kwargs)
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:63, in Module.<locals>.Module.forward(self, *args, **kwargs)
62 def forward(self, *args, **kwargs):
---> 63 return forward_fn(*args, **kwargs)
File /data0/home/h21/luas6629/dummy/lib/python3.10/site-packages/a_unet/blocks.py:534, in ClassifierFreeGuidancePlugin.<locals>.Net.<locals>.forward(x, embedding, embedding_scale, embedding_mask_proba, **kwargs)
526 def forward(
527 x: Tensor,
528 embedding: Optional[Tensor] = None,
(...)
531 **kwargs,
532 ):
533 msg = "ClassiferFreeGuidancePlugin requires embedding"
--> 534 assert exists(embedding), msg
535 b, device = embedding.shape[0], embedding.device
536 embedding_mask = fixed_embedding(embedding)
AssertionError: ClassiferFreeGuidancePlugin requires embedding
Is it about dependencies?
What dependencies am I supposed to install?
P.S. Could you please show two simple Colab examples:
1. how to train on my own WAV files
2. how to use pretrained networks and fine-tune them on my own WAV files
I am trying to understand how to condition on text to validate a research idea in bioacoustics, but I do not yet have a strong enough foundation to fully understand your code, so a tutorial would be really helpful.
Hi, I tested the example you gave for conditioning on text, but got an error:
Error:
Is it about dependencies? What dependencies am I supposed to install?
P.s. can you please show two simple colab examples:
I am trying to understand how to condition on text to validate a research idea in bioacoustics, but I do not yet have a strong enough foundation to fully understand your code, so a tutorial would be really helpful.