csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

SISDR loss #28

Closed: fncode246 closed this issue 2 years ago

fncode246 commented 3 years ago

Hi, thanks for sharing this useful repo. I want to use SI-SDR as a loss function during training. Here is an example of the input and target:

inputs = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)
loss = SISDRLoss(inputs, target)
print(loss)

But it only prints:

SISDRLoss()

and I can't compute the SI-SDR loss value from this. Please help me. Thanks!

csteinmetz1 commented 3 years ago

Hi, thanks for checking out auraloss.

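I think the issue is that you need to instantiate the SISDRLoss module first and then call that instance on your tensors. Calling the class directly, as in SISDRLoss(inputs, target), constructs a new module instead of computing a loss, which is why printing it shows SISDRLoss(). For example: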

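import torch
from auraloss.time import SISDRLoss

inputs = torch.rand(8, 1, 44100)  # (batch, channels, samples)
target = torch.rand(8, 1, 44100)
sisdr_loss = SISDRLoss()  # instantiate the loss module
loss = sisdr_loss(inputs, target)  # measure the loss (a scalar tensor)
print(loss)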

Let me know if this addresses your issue.
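For illustration, here is a minimal sketch of a negative SI-SDR loss written as a plain PyTorch module. ToySISDRLoss is a hypothetical name and this is not auraloss's implementation; it assumes the common SI-SDR definition (zero-mean signals, optimal scaling of the target onto the estimate, ratio in dB) plus a small eps for numerical safety. It also shows the general nn.Module pattern behind the fix above: calling the class builds a module, and calling the instance runs forward() and returns the loss tensor.

```python
import torch
import torch.nn as nn


class ToySISDRLoss(nn.Module):
    """Hypothetical negative SI-SDR loss (illustrative only, not auraloss's code)."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Remove the mean along time so constant offsets are ignored.
        input = input - input.mean(dim=-1, keepdim=True)
        target = target - target.mean(dim=-1, keepdim=True)
        # Optimal scale factor that projects the estimate onto the target.
        alpha = (input * target).sum(-1, keepdim=True) / (
            target.pow(2).sum(-1, keepdim=True) + self.eps
        )
        scaled_target = alpha * target
        noise = input - scaled_target
        # SI-SDR in dB for every (batch, channel) pair.
        ratio = scaled_target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + self.eps)
        sisdr = 10 * torch.log10(ratio + self.eps)
        # Negate so a higher SI-SDR gives a lower loss; average over batch and channels.
        return -sisdr.mean()


inputs = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)
sisdr_loss = ToySISDRLoss()        # calling the class builds a module
loss = sisdr_loss(inputs, target)  # calling the instance computes the loss
print(loss)
```

Because the measure is scale invariant, multiplying inputs by a constant should leave the loss essentially unchanged, while a perfect estimate gives a much lower (more negative) loss than random noise.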

fncode246 commented 3 years ago

Yes, it works. Thanks

fncode246 commented 3 years ago

Is there any way to use the SI-SDR loss with PIT (permutation invariant training, so the order of the sources doesn't matter when training a source separation model)?

csteinmetz1 commented 2 years ago

Sorry for the delay in responding to this. We currently don't support PIT and are unlikely to add it. You could try combining an auraloss loss function with the PIT wrapper from Asteroid (https://github.com/asteroid-team/pytorch-pit/blob/master/torch_pit/pit_wrapper.py).
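As a rough illustration of what a PIT wrapper does, the sketch below computes a per-source negative SI-SDR for every assignment of estimated sources to reference sources and keeps the lowest-loss assignment for each example. The names neg_sisdr and pit_loss are hypothetical, the (batch, n_src, time) tensor layout is an assumption, and this is not the Asteroid wrapper's API. The brute-force search over all n_src! permutations is only practical for a small number of sources.

```python
import itertools

import torch


def neg_sisdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical per-example negative SI-SDR: (batch, time) -> (batch,)."""
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    proj = alpha * ref
    noise = est - proj
    ratio = proj.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps)
    return -10 * torch.log10(ratio + eps)


def pit_loss(est: torch.Tensor, ref: torch.Tensor):
    """Hypothetical PIT loss for est, ref of shape (batch, n_src, time).

    Returns the batch mean of the best-permutation loss and, per example,
    the index of the winning permutation in itertools.permutations order.
    """
    n_src = est.shape[1]
    losses = []
    for perm in itertools.permutations(range(n_src)):
        # Pair estimate i with reference perm[i] and average over sources.
        per_src = [neg_sisdr(est[:, i], ref[:, j]) for i, j in enumerate(perm)]
        losses.append(torch.stack(per_src, dim=1).mean(dim=1))  # (batch,)
    losses = torch.stack(losses, dim=1)  # (batch, n_perms)
    # Gradients flow only through the selected (minimum) assignment.
    best_loss, best_idx = losses.min(dim=1)
    return best_loss.mean(), best_idx


ref = torch.randn(4, 2, 16000)                          # (batch, n_src, time)
est = ref[:, [1, 0]] + 0.01 * torch.randn(4, 2, 16000)  # sources in swapped order
loss, best_perm = pit_loss(est, ref)
print(loss, best_perm)
```

Swapping the order of the estimated sources should not change the PIT loss; only the selected permutation changes. That property is what lets a separation model be trained without committing to a fixed source order.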