CarolineCheng233 opened 2 years ago
Hi, I'm afraid we don't support that. The dataloader we use is somewhat complex: it assigns IDs dynamically to each file and it has a wider set of options for different experiments we didn't publish. We don't provide the training code, as it would be too costly for us to help everyone address the issues.
If you just want to train the model using the face landmarks and the audio, it's fairly simple to create a dataloader. You just need 2 audios plus the landmarks corresponding to the target face. The audios are normalized with respect to their absolute max.
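A minimal sketch of that normalization (illustrative only; the epsilon guard and function name are not from the repo):

import torch

def normalize_wav(wav: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Peak-normalize a waveform with respect to its absolute maximum
    return wav / (wav.abs().max() + eps)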
The core network is the one called AudioVisualNetwork. In its forward you can find a line similar to:
audio_feats = self.audio_processor.preprocess_audio(s1,s2)
This will create the mixtures for you. Then you just need to create the ground-truth masks (which are bounded for stability) with the following code. Note that the spectrograms are batch, freq, time, [real, imaginary]:
import torch
from einops import rearrange
from torch.nn.functional import l1_loss, mse_loss

def get_loss_mask(self, logits_mask, sp_mix, sources):
    spm = rearrange(sources[0], 'b c f t -> b f t c')
    sp_mix = rearrange(sp_mix, 'b c f t -> b f t c')
    # Masks are typically bounded when used as a loss target.
    # Two types are computed:
    #   loss_mask is the one which would be used in a loss function
    #   inference_mask is the one that, multiplied by the mixture, gives the independent source
    if self.complex_enabled:
        loss_mask = rearrange(self.tanh(logits_mask), 'b c f t -> b f t c')
        gt_mask = self.complex_mask(spm, sp_mix)
    else:
        raise NotImplementedError('Not tested')
    return loss_mask, gt_mask
def compute_loss(self, pred, gt, weight, loss_criterion='mse'):
    if self.complex_enabled or not self.loss_on_mask:
        # Complex masks and gt have shape BxFxTxC, so the weight needs unsqueezing for broadcasting.
        # The same applies to direct estimation, in which the mask multiplies the mixture as real + imag.
        # Ratio masks, however, are applied over the magnitude, so no broadcasting is needed.
        weight = weight.unsqueeze(-1)
        pred = torch.view_as_real(pred) if not self.loss_on_mask else pred
    assert pred.shape == gt.shape, 'Mask computation: ground truth and predictions have to be the same shape'
    if loss_criterion.lower() == 'mse':
        if self.weighted_loss:
            loss = (weight * (pred - gt).pow(2)).mean()
        else:
            loss = mse_loss(pred, gt)
    elif loss_criterion.lower() == 'l1':
        if self.weighted_loss:
            loss = (weight * (pred - gt).abs()).mean()
        else:
            loss = l1_loss(pred, gt)
    else:
        raise NotImplementedError(f'Loss criterion {loss_criterion} not implemented')
    return loss
@staticmethod
def tanh(x):
    # Bounded compression K * (1 - exp(-C * x)) / (1 + exp(-C * x)),
    # computed with torch.tanh to deal with the asymptotic values.
    # Manually coded at https://github.com/vitrioil/Speech-Separation/blob/master/src/models/complex_mask_utils.py
    K = 10
    return K * torch.tanh(x)
@torch.no_grad()
def complex_mask(self, sp0, sp_mix, eps=torch.finfo(torch.float32).eps):
    # Bibliography about complex masks:
    # http://homes.sice.indiana.edu/williads/publication_files/williamsonetal.cRM.2016.pdf
    assert sp0.shape == sp_mix.shape
    sp_mix = sp_mix + eps
    mask = complex_division(sp0, sp_mix) / self.n_sources
    mask_bounded = self.tanh(mask)
    return mask_bounded
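For reference, complex_division operates on spectrograms whose last dimension holds [real, imaginary]. An illustrative stand-in (not the actual implementation in the repo) could be:

import torch

def complex_division_sketch(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
    # (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
    # num, den: tensors of shape ... x 2 with [real, imag] in the last dim
    a, b = num[..., 0], num[..., 1]
    c, d = den[..., 0], den[..., 1]
    denom = c ** 2 + d ** 2
    return torch.stack(((a * c + b * d) / denom, (b * c - a * d) / denom), dim=-1)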
I can give you parts of the code, but I don't think they would be any clearer than this (which has been polished).
The refinement network is trained once the main stage is trained, as described in the paper. We used batch size 10.
Best, Juan
The weight is:
weight = torch.log1p(magnitude_sp_mix)
weight = torch.clamp(weight, 1e-3, 10)
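To give an idea of how these pieces fit together, here is a rough sketch of one training step. The forward signature, the dictionary keys and the attribute layout (audio_processor, get_loss_mask and compute_loss living on the same object) are assumptions for illustration, not the repo's actual API:

def training_step(model, optimizer, s1, s2, landmarks, loss_criterion='mse'):
    # s1, s2: batches of waveforms (target speaker + interfering audio); landmarks: visual stream
    feats = model.audio_processor.preprocess_audio(s1, s2)   # builds the mixture and gt spectrograms
    logits_mask = model(feats['mixture'], landmarks)          # forward pass (signature assumed)
    loss_mask, gt_mask = model.get_loss_mask(logits_mask, feats['mixture'], feats['sources'])
    loss = model.compute_loss(loss_mask, gt_mask, feats['weight'], loss_criterion)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()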
Hi @JuanFMontesinos ,
Thank you for the code. Great paper! I am indeed trying to replicate the training.
I had a few questions and would be grateful if you could please answer them:
1) audio_feats = self.audio_processor.preprocess_audio(s1,s2) --> where exactly in this function is the mixture of the two audios created? I only see the first file being read here (src[0]): https://github.com/JuanFMontesinos/VoViT/blob/499a689860d832edc3b5c744ef5d62eeddc96d3b/vovit/core/models/production_model.py#L121-L139
Do you use a different preprocess_audio while training, or am I missing something here?
2) How do you create the extra dimension here, i.e. the [real, imaginary] part of batch, freq, time, [real, imaginary]?
Thank you!
Hi @Sreyan88 The code over there is just for inference (it's a simplified function). At training time we create the mixtures by mixing 2 or more audios (depending on the setup). The preprocess_audio function used for training looks like:
def preprocess_audio(self, *src: list, real_sample=False, n_sources=2):
    """
    Inputs contain the following keys:
        audio: the main audio waveform of shape N,M
        audio_acmt: the secondary audio waveform of shape N,M
        src: if running inference on real mixtures, the mixture audio waveform of shape N,M
    """
    self.n_sources = n_sources if real_sample else len(src)
    self.n_sources = self.n_sources + 1 if self.remix_input else self.n_sources
    with torch.no_grad():
        if real_sample:
            # Inference on a real mixture
            self.n_sources = n_sources
            sp_mix_raw = self.wav2sp(src[0]).contiguous() / self.n_sources
            if self.downsample_coarse:
                sp_mix = sp_mix_raw[:, ::2, ...]  # BxFxTx2
            elif self.downsample_interp:
                raise NotImplementedError
            else:
                sp_mix = sp_mix_raw
        else:
            sp = [self.wav2sp(x) for x in src]  # Spectrogram shape BxFxTx2
            assert len(sp) > 0, 'Empty list of audio inputs'
            # This is the main spectrogram, which is the target if there is a single output modality.
            # Used when remix_input is enabled.
            spm = sp[0]
            if self.remix_input:
                B = spm.shape[0]  # Batch elements
                ndim = spm.ndim
                coef = (torch.rand(B, *[1 for _ in range(ndim - 1)], device=spm.device) < self.remix_coef).byte()
                sources = sp + [torch.roll(spm, shifts=1, dims=0) * coef]
            else:
                sources = sp
            sp_mix_raw = (sum(sources) / self.n_sources).contiguous()
            sources_raw = sources
            # Downsampling to save memory
            if self.downsample_coarse:
                # Shape becomes B x F/2 x T x 2
                sources = [x[:, ::2, ...] for x in sources_raw]
                sp_mix = sp_mix_raw[:, ::2, ...]
            elif self.downsample_interp:
                raise NotImplementedError
            else:
                sp_mix = sp_mix_raw
        mag = sp_mix.norm(dim=-1)  # Magnitude spectrogram BxFxT
        if self.weighted_loss:
            weight = torch.log1p(mag)
            weight = torch.clamp(weight, 1e-3, 10)
        else:
            weight = None
        if self.complex_enabled:
            x = rearrange(sp_mix, 'b f t c -> b c f t')
            if not real_sample:
                sources = [rearrange(x, 'b f t c -> b c f t') for x in sources]
        elif self.log_sp_enabled:
            epsilon = 1e-4
            x = (mag + epsilon).log().unsqueeze(1)
            if not real_sample:
                sources = [(x.norm(dim=-1) + epsilon).log().unsqueeze(1) for x in sources]
        elif self.mel_enabled:
            x = self.sp2mel(mag).unsqueeze(1)
            sources = [x.norm(dim=-1).unsqueeze(1) for x in sources]
            raise NotImplementedError('Option not implemented in depth. Draft written.')
        else:
            x = mag.unsqueeze(1)
            if not real_sample:
                sources = [x.norm(dim=-1).unsqueeze(1) for x in sources]
    output = {'mixture': x, 'sp_mix_raw': sp_mix_raw}
    if not real_sample:
        output.update({'src_sp_raw': sources_raw, 'sources': sources, 'weight': weight})
    return output
The mix is created here:
if self.remix_input:
    B = spm.shape[0]  # Batch elements
    ndim = spm.ndim
    coef = (torch.rand(B, *[1 for _ in range(ndim - 1)], device=spm.device) < self.remix_coef).byte()
    sources = sp + [torch.roll(spm, shifts=1, dims=0) * coef]
We are basically computing a binary mask from random coefficients and permuting (rolling) the elements of the batch.
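A toy example of that remixing step (shapes shrunk for readability; remix_coef here is just a placeholder value):

import torch

B, F, T, C = 4, 6, 5, 2
remix_coef = 0.5                                   # placeholder probability
spm = torch.randn(B, F, T, C)                      # main spectrogram, BxFxTx[real, imag]

# One Bernoulli draw per batch element: keep (1) or zero out (0) the extra source
coef = (torch.rand(B, 1, 1, 1) < remix_coef).byte()
extra_source = torch.roll(spm, shifts=1, dims=0) * coef  # each item receives its neighbour's spectrogram, or silence
print(extra_source.shape)                          # torch.Size([4, 6, 5, 2])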
Regarding your second question: the extra [real, imaginary] dimension comes from using return_complex=False in the stft operator.
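A quick way to see that layout (the n_fft/hop/window values here are arbitrary):

import torch

wav = torch.randn(1, 16000)
win = torch.hann_window(400)
# return_complex=False gives a real tensor with a trailing [real, imag] dim;
# the same layout can be obtained from the complex output via view_as_real:
sp_c = torch.stft(wav, n_fft=400, hop_length=100, window=win, return_complex=True)
sp = torch.view_as_real(sp_c)
print(sp.shape)  # torch.Size([1, 201, 161, 2]) -> the trailing 2 is [real, imag]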
Hi @JuanFMontesinos ,
Thank you so much for this! I will try to keep this discussion alive until I am able to replicate the training!
Sure @Sreyan88. Note that it takes ~20 days to train on a 3090. At the beginning, it may take a while to start converging if you use only 2 voices, as it's a hard problem.
Hi @JuanFMontesinos ,
I have two questions and would be grateful if you could answer them!
(1) Can you please provide me with the code for complex_division()?
(2) I want to generate complex masks for speech dereverberation. What should n_sources be in mask = complex_division(sp0, sp_mix) / self.n_sources?
Hi @Sreyan88
complex_division is defined here:
https://github.com/JuanFMontesinos/VoViT/blob/dfd11c2ec5faa09f51cf68ae7ab42e5794941468/vovit/core/models/production_model.py#L68-L73
Regarding the masks: masks are usually defined as $T(k,l)=M(k,l)\odot S(k,l)$, where $\odot$ denotes the complex product, $T$ is the target source and $S$ is the mixture. If we now assume there are two sources, $S=(T_1+T_2)/2$. Note that sources can be combined in different ways; taking the average is just a widely used one. Some papers take the maximum or apply a weighted average. Then
$M=T_1/S=2\frac{T_1}{T_1+T_2}$
In our case we divide by n_sources to remove that scaling factor, so that our masks are defined as $M=\frac{T_1}{T_1+T_2}$.
The interpretation is the following: if you do not divide by the scaling factor, you are making the network estimate how many sources there are and compensate to recover the original loudness. If you divide by the scaling factor, the network just isolates the source at its current loudness.
Since we know the number of sources, we proceeded as in the second case. However, it's a very minor, picky detail; you shouldn't worry about it.
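A tiny numeric check of that scaling (values are arbitrary):

import torch

t1 = torch.tensor(0.3 + 0.4j)        # target source
t2 = torch.tensor(-0.1 + 0.2j)       # second source
n_sources = 2

s = (t1 + t2) / n_sources            # mixture as the average of the sources
mask_loud = t1 / s                    # = 2 * T1 / (T1 + T2): compensates the averaging
mask = mask_loud / n_sources          # = T1 / (T1 + T2): the definition used here
assert torch.allclose(mask_loud * s, t1)               # recovers T1 at its original loudness
assert torch.allclose(mask * s, t1 / n_sources)        # isolates T1 at its loudness inside the mixture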
Thank you so much for your reply! Can you also please help me with your stft parameters?
As stated in Sec. 3.2 of the paper, audio is resampled to 16384 Hz and the stft is computed with a hop size of 256 and an n_fft of 1022.
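For completeness, a sketch of that audio front end (the Hann window and the torchaudio resampling call are assumptions, not taken from the repo):

import torch
import torchaudio

wav, sr = torchaudio.load('target_speaker.wav')          # hypothetical file
wav = torchaudio.functional.resample(wav, sr, 16384)     # resample to 16384 Hz (Sec. 3.2)
sp = torch.stft(wav, n_fft=1022, hop_length=256,
                window=torch.hann_window(1022),          # window choice is an assumption
                return_complex=True)
sp = torch.view_as_real(sp)                              # B x 512 x T x [real, imag]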
Can you please provide the code for the training and testing process? I also wonder about the detailed training settings, such as the batch size. Thanks a lot!