Closed: MachineJeff closed this issue 4 years ago
It should be removed at inference time. Weight norm splits each weight tensor into a norm (magnitude) and a direction. This decomposition helps at training time, but there is no need to recompute the product of norm and direction on every forward pass, so we remove the reparametrization by multiplying them back into a single weight tensor once, up front.
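To make the "multiplying them at first" point concrete, here is a small sketch using PyTorch's `torch.nn.utils.weight_norm` and `remove_weight_norm` (the same utilities the repo relies on). It shows that the reparametrized weight is `w = g * v / ||v||`, and that removal just folds `g` and `v` back into one plain `weight` tensor:

```python
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm

torch.manual_seed(0)
conv = weight_norm(nn.Conv1d(2, 2, 3))

# weight_norm reparametrizes the weight as w = g * v / ||v||,
# stored as the separate parameters weight_g and weight_v
g, v = conv.weight_g, conv.weight_v
w_manual = g * v / v.norm(dim=(1, 2), keepdim=True)
print(torch.allclose(conv.weight, w_manual))  # True

# remove_weight_norm folds g and v back into a single weight tensor,
# so forward passes no longer recompute the product
remove_weight_norm(conv)
print(torch.allclose(conv.weight, w_manual))  # True
print(hasattr(conv, 'weight_g'))              # False
```

So removal changes nothing numerically; it only skips the per-forward recomputation of the norm/direction product.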
Sorry, I still cannot understand this operation.
In your code, `melgan/inference.py` line 32:
```python
with torch.no_grad():
    for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
        mel = torch.load(melpath)
        if len(mel.shape) == 2:
            mel = mel.unsqueeze(0)
        mel = mel.cuda()

        audio = model.inference(mel)
        audio = audio.cpu().detach().numpy()

        out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
        write(out_path, hp.audio.sampling_rate, audio)
```
The key line is `audio = model.inference(mel)`, and in `melgan/model/generator.py`:
```python
def forward(self, mel):
    mel = (mel + 5.0) / 5.0  # roughly normalize spectrogram
    return self.generator(mel)

def inference(self, mel):
    hop_length = 256
    # pad input mel with zeros to cut artifact
    # see https://github.com/seungwonpark/melgan/issues/8
    zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
    mel = torch.cat((mel, zero), dim=2)

    audio = self.forward(mel)
    audio = audio.squeeze()  # collapse all dimension except time axis
    audio = audio[:-(hop_length * 10)]
    audio = MAX_WAV_VALUE * audio
    audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
    audio = audio.short()

    return audio
```
The `inference` API calls the `forward` API, and `forward` calls the generator. See? It has nothing to do with removing weight norm. @seungwonpark
I'm not clear on whether the results with and without `remove_weight_norm` are the same. If they differ, which one is right, or better? @MachineJeff @seungwonpark
Why remove weight norm in eval? At inference time, should weight norm be kept or removed?
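For what it's worth, the two should be numerically equivalent, since removal only folds the norm/direction product into a fixed weight. Here is a toy sketch (a stand-in stack of weight-normed convs, not the actual MelGAN classes) checking that the generator-style output is unchanged after removal:

```python
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm

# Toy stand-in for a mel-to-audio generator: a stack of
# weight-normed 1-D convs (illustrative shapes, not MelGAN's).
net = nn.Sequential(
    weight_norm(nn.Conv1d(80, 32, 7, padding=3)),
    nn.LeakyReLU(0.2),
    weight_norm(nn.Conv1d(32, 1, 7, padding=3)),
)
net.eval()

mel = torch.randn(1, 80, 50)
with torch.no_grad():
    kept = net(mel)                      # weight norm still attached
    for m in net.modules():
        if isinstance(m, nn.Conv1d):
            remove_weight_norm(m)        # fold g * v / ||v|| into weight
    removed = net(mel)

print(torch.allclose(kept, removed, atol=1e-5))  # True
```

So keeping it is not wrong, just wasteful: removal avoids recomputing the norm on every forward pass.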