In the current implementation, the `forward()` method is generic for both train and eval mode. In some cases, we need not only the loss but also the prediction as output, so that extra metrics such as SDR can be computed during the validation step.
Because the loss function code is common to the `BSRoformer` and `MelBandRoformer` classes, it may be better to create a new class, such as `MultiResLoss`, for maximum flexibility:
import torch
import torch.nn.functional as F
from einops import rearrange
from beartype import beartype
from beartype.typing import Tuple
class MultiResLoss:
    """Waveform L1 loss plus a weighted multi-resolution STFT L1 loss.

    Instances are plain callables: ``loss_fn(predict, targets)`` returns the
    L1 distance between the two signals, plus the sum over several STFT
    window sizes of the L1 distance between their complex spectrograms,
    scaled by ``multi_stft_resolution_loss_weight``.
    """

    @beartype
    def __init__(
        self,
        num_stems,
        stft_n_fft,
        multi_stft_resolution_loss_weight = 1.,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
        multi_stft_hop_size = 147,
        multi_stft_normalized = False,
        multi_stft_window_fn = None
    ):
        """Store the loss configuration.

        Args:
            num_stems: Number of stems. When greater than 1, ``targets``
                passed to ``__call__`` must be 4-dimensional with this size
                on dim 1.
            stft_n_fft: Lower bound for the ``n_fft`` used at every
                resolution (``n_fft = max(window_size, stft_n_fft)``).
            multi_stft_resolution_loss_weight: Factor applied to the summed
                STFT loss before it is added to the waveform loss.
            multi_stft_resolutions_window_sizes: ``win_length`` values; one
                STFT is computed per entry.
            multi_stft_hop_size: ``hop_length`` forwarded to ``torch.stft``.
            multi_stft_normalized: ``normalized`` flag forwarded to
                ``torch.stft``.
            multi_stft_window_fn: Optional callable invoked as
                ``fn(window_size, device = ...)`` that returns the analysis
                window for a resolution (for example ``torch.hann_window``).
                ``None`` (the default) passes no window to ``torch.stft``,
                which is the behaviour this class had before the parameter
                existed.
        """
        self.num_stems = num_stems

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        # Keyword arguments shared by every torch.stft call, whatever the resolution.
        self.multi_stft_kwargs = dict(
            hop_length = multi_stft_hop_size,
            normalized = multi_stft_normalized
        )

    def __call__(
        self,
        predict,
        targets,
        return_loss_breakdown = False
    ):
        """Compute the combined loss between ``predict`` and ``targets``.

        Args:
            predict: Predicted signal tensor; its last axis is the one the
                STFT is taken over.
            targets: Reference signal tensor. A 2-D tensor gets a singleton
                axis inserted before its last axis. The last axis is cut to
                the length of ``predict``'s last axis.
            return_loss_breakdown: When true, also return the individual
                loss terms.

        Returns:
            ``total_loss`` alone, or
            ``(total_loss, (loss, multi_stft_resolution_loss))`` when
            ``return_loss_breakdown`` is true. The STFT term in the breakdown
            is the unweighted sum.

        Raises:
            AssertionError: If ``num_stems > 1`` and ``targets`` is not 4-D
                with ``num_stems`` entries on dim 1.
        """
        if self.num_stems > 1:
            assert targets.ndim == 4 and targets.shape[1] == self.num_stems, (
                f'targets must be 4-dimensional with {self.num_stems} stems on dim 1, '
                f'got shape {tuple(targets.shape)}'
            )

        if targets.ndim == 2:
            targets = rearrange(targets, '... t -> ... 1 t')

        targets = targets[..., :predict.shape[-1]]  # protect against lost length on istft

        loss = F.l1_loss(predict, targets)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                # max() keeps n_fft >= win_length.
                # NOTE(review): the original author was unsure what n_fft
                # should be across the multi-resolution STFTs.
                n_fft = max(window_size, self.multi_stft_n_fft),
                win_length = window_size,
                return_complex = True,
                **self.multi_stft_kwargs,
            )

            # Only pass a window when one was configured, so the default
            # call is identical to the one made before the option existed.
            if self.multi_stft_window_fn is not None:
                res_stft_kwargs['window'] = self.multi_stft_window_fn(
                    window_size, device = predict.device
                )

            # Collapse every leading axis so torch.stft receives a 2-D input.
            predict_Y = torch.stft(rearrange(predict, '... s t -> (... s) t'), **res_stft_kwargs)
            targets_Y = torch.stft(rearrange(targets, '... s t -> (... s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(predict_Y, targets_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
In the same spirit, a small refactoring could be to create a new file for the common classes:
In the current implementation, the `forward()` method is generic for both train and eval mode. In some cases, we need not only the loss but also the prediction as output, so that extra metrics such as SDR can be computed during the validation step. Because the loss function code is common to the `BSRoformer` and `MelBandRoformer` classes, it may be better to create a new class, such as `MultiResLoss`, for maximum flexibility.
In the same spirit, a small refactoring could be to create a new file for the common classes:
Would that make future changes to the code easier?