If we are to adopt the inference script from our inspiration repo, then we need to do the following:
Create a dataset & dataloader for predictions, i.e. something that returns the equivalent of:
    import numpy as np

    # Noise schedules: the training one and, for fast sampling, a shorter inference one.
    training_noise_schedule = np.array(model.params.noise_schedule)
    inference_noise_schedule = np.array(model.params.inference_noise_schedule) if fast_sampling else training_noise_schedule

    talpha = 1 - training_noise_schedule
    talpha_cum = np.cumprod(talpha)

    beta = inference_noise_schedule
    alpha = 1 - beta
    alpha_cum = np.cumprod(alpha)

    # For each inference step s, find the adjacent training steps t, t+1 whose
    # cumulative alpha brackets alpha_cum[s], then interpolate to a fractional step.
    T = []
    for s in range(len(inference_noise_schedule)):
        for t in range(len(training_noise_schedule) - 1):
            if talpha_cum[t+1] <= alpha_cum[s] <= talpha_cum[t]:
                twiddle = (talpha_cum[t]**0.5 - alpha_cum[s]**0.5) / (talpha_cum[t]**0.5 - talpha_cum[t+1]**0.5)
                T.append(t + twiddle)
                break
    T = np.array(T, dtype=np.float32)
This last `T` is what is going to be an element of a batch.
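As a rough sketch of how this could be packaged, the loop above can be wrapped in a function and paired with a small map-style dataset. The names `align_inference_schedule` and `PredictionDataset`, and the choice to return a `(spectrogram, T)` pair per item, are assumptions made for illustration and not something the inspiration repo defines.

```python
import numpy as np


def align_inference_schedule(training_noise_schedule, inference_noise_schedule):
    """Map each inference step to a fractional training-step index.

    Same logic as the loop above: for every inference step s, find the
    adjacent training steps whose cumulative alpha brackets alpha_cum[s],
    then interpolate between them in sqrt(alpha_cum) space.
    """
    talpha_cum = np.cumprod(1 - np.asarray(training_noise_schedule, dtype=np.float64))
    alpha_cum = np.cumprod(1 - np.asarray(inference_noise_schedule, dtype=np.float64))
    T = []
    for s in range(len(alpha_cum)):
        for t in range(len(talpha_cum) - 1):
            if talpha_cum[t + 1] <= alpha_cum[s] <= talpha_cum[t]:
                twiddle = (talpha_cum[t] ** 0.5 - alpha_cum[s] ** 0.5) / (
                    talpha_cum[t] ** 0.5 - talpha_cum[t + 1] ** 0.5
                )
                T.append(t + twiddle)
                break
    return np.array(T, dtype=np.float32)


class PredictionDataset:
    """Hypothetical map-style dataset: one (spectrogram, T) pair per item."""

    def __init__(self, spectrograms, T):
        self.spectrograms = list(spectrograms)
        self.T = T  # shared by all items; computed once up front

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        return self.spectrograms[idx], self.T


# Illustrative schedules only; the real values would come from model.params.
training = np.linspace(1e-4, 0.05, 50)
inference = [0.0001, 0.001, 0.01, 0.05, 0.2, 0.5]
dataset = PredictionDataset([np.zeros((80, 100))], align_inference_schedule(training, inference))
```

Because the class only implements `__len__` and `__getitem__`, it should be usable as a map-style dataset with `torch.utils.data.DataLoader`. Spectrograms of different lengths would likely need a batch size of 1 or a custom `collate_fn`. Since `T` is the same for every item, computing it once inside the module instead of shipping it in each batch may be a simpler alternative.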
Our Lightning module needs to implement a `predict_step` that is going to return something equivalent to:
    if not model.params.unconditional:
        if len(spectrogram.shape) == 2:  # Expand rank 2 tensors by adding a batch dimension.
            spectrogram = spectrogram.unsqueeze(0)
        spectrogram = spectrogram.to(device)
        audio = torch.randn(spectrogram.shape[0], model.params.hop_samples * spectrogram.shape[-1], device=device)
    else:
        audio = torch.randn(1, params.audio_len, device=device)
    noise_scale = torch.from_numpy(alpha_cum**0.5).float().unsqueeze(1).to(device)

    for n in range(len(alpha) - 1, -1, -1):
        c1 = 1 / alpha[n]**0.5
        c2 = beta[n] / (1 - alpha_cum[n])**0.5
        audio = c1 * (audio - c2 * model(audio, torch.tensor([T[n]], device=audio.device), spectrogram).squeeze(1))
        if n > 0:
            noise = torch.randn_like(audio)
            sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5
            audio += sigma * noise
        audio = torch.clamp(audio, -1.0, 1.0)
    return audio, model.params.sample_rate
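For reference, the update the loop applies at each step can be written out as below. The symbols are my own shorthand, not the repo's: x_n is `audio` before step n, \bar\alpha_n is `alpha_cum[n]`, \epsilon_\theta is the model call, c is the spectrogram conditioning, and T_n is the aligned step index `T[n]`.

```latex
\begin{aligned}
x_{n-1} &= \frac{1}{\sqrt{\alpha_n}}
  \left( x_n - \frac{\beta_n}{\sqrt{1 - \bar\alpha_n}}\,
  \epsilon_\theta(x_n, T_n, c) \right) + \sigma_n z,
  \qquad z \sim \mathcal{N}(0, I), \\
\sigma_n &= \sqrt{\frac{1 - \bar\alpha_{n-1}}{1 - \bar\alpha_n}\,\beta_n}
  \quad \text{for } n > 0, \qquad \sigma_0 = 0 .
\end{aligned}
```

Here `c1` and `c2` in the code correspond to 1/\sqrt{\alpha_n} and \beta_n/\sqrt{1-\bar\alpha_n}. The code additionally clamps the waveform to [-1, 1], which the equations do not show, and the result of the step at n = 0 is the returned audio.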
Lastly, inference needs to work out of the box given a pre-trained model. This should be quite easy.
We have a model trained, what now?