seungwonpark / melgan

MelGAN vocoder (compatible with NVIDIA/tacotron2)
http://swpark.me/melgan/
BSD 3-Clause "New" or "Revised" License
633 stars, 116 forks

Why remove weight norm? #37

Closed MachineJeff closed 4 years ago

MachineJeff commented 4 years ago

Why is weight norm removed in eval? At inference time, should weight norm be kept or removed?

seungwonpark commented 4 years ago

It should be removed at inference time. Weight norm splits each weight into the norm (magnitude) and the direction of the weight vector. That helps at training time, but at inference we don't need to recompute the product of norm and direction on every forward pass, so we multiply them once up front and remove the reparameterization.
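To make the point above concrete, here is a minimal pure-Python sketch of the idea, assuming the usual weight-norm reparameterization w = g * v / ||v||. The helper names (`effective_weight`, `forward_with_weight_norm`, `forward_folded`) are hypothetical and are not taken from the repo; the sketch only shows that folding g and v into a single weight once gives the same output as recomputing the product on every call.

```python
import math

def effective_weight(g, v):
    # Weight norm stores a magnitude g and a direction vector v;
    # the weight actually used is w = g * v / ||v||.
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]

def forward_with_weight_norm(g, v, inputs):
    # Training-style forward: w is rebuilt from (g, v) on every call.
    w = effective_weight(g, v)
    return sum(wi * xi for wi, xi in zip(w, inputs))

def forward_folded(w, inputs):
    # Inference-style forward: w was computed once and is reused as-is.
    return sum(wi * xi for wi, xi in zip(w, inputs))

g = 2.0
v = [3.0, 4.0]           # ||v|| = 5
x = [1.0, -1.0]

w = effective_weight(g, v)                     # folded once, up front
y_train = forward_with_weight_norm(g, v, x)    # recomputes w internally
y_infer = forward_folded(w, x)                 # uses the folded w

print(y_train, y_infer)
```

Both calls use the same effective weight, so the two outputs agree up to floating-point rounding; the folded version just skips the per-call normalization.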

MachineJeff commented 4 years ago

Sorry, I still don't understand this operation.

In your code, melgan/inference.py, line 32:


    with torch.no_grad():
        for melpath in tqdm.tqdm(glob.glob(os.path.join(args.input_folder, '*.mel'))):
            mel = torch.load(melpath)
            if len(mel.shape) == 2:
                mel = mel.unsqueeze(0)
            mel = mel.cuda()

            audio = model.inference(mel)
            audio = audio.cpu().detach().numpy()

            out_path = melpath.replace('.mel', '_reconstructed_epoch%04d.wav' % checkpoint['epoch'])
            write(out_path, hp.audio.sampling_rate, audio)

The key line is audio = model.inference(mel)

And in melgan/model/generator.py:

    def forward(self, mel):
        mel = (mel + 5.0) / 5.0 # roughly normalize spectrogram
        return self.generator(mel)

    def inference(self, mel):
        hop_length = 256
        # pad input mel with zeros to cut artifact
        # see https://github.com/seungwonpark/melgan/issues/8
        zero = torch.full((1, self.mel_channel, 10), -11.5129).to(mel.device)
        mel = torch.cat((mel, zero), dim=2)

        audio = self.forward(mel)
        audio = audio.squeeze() # collapse all dimension except time axis
        audio = audio[:-(hop_length*10)]
        audio = MAX_WAV_VALUE * audio
        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE-1)
        audio = audio.short()

        return audio

The inference API calls the forward API, and the forward API then calls the generator.

See? It has nothing to do with removing weight norm. @seungwonpark

Liujingxiu23 commented 4 years ago

I'm not clear on whether the results with and without remove_weight_norm are the same. If not, which one is right, or better? @MachineJeff @seungwonpark
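One way to check this empirically is sketched below. It is only an illustration on a standalone Conv1d layer, not the repo's Generator, and it assumes PyTorch's `torch.nn.utils.weight_norm` and `torch.nn.utils.remove_weight_norm` (newer PyTorch releases mark `weight_norm` as deprecated in favor of a parametrization-based version, but it should still work). The layer sizes and the seed are arbitrary choices.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

torch.manual_seed(0)

# A single weight-normalized layer (arbitrary sizes, for illustration only).
conv = weight_norm(nn.Conv1d(4, 8, kernel_size=3, padding=1))
x = torch.randn(1, 4, 16)

with torch.no_grad():
    # Output while the layer still holds weight_g (norm) and weight_v (direction).
    y_before = conv(x)

    # Fold norm and direction into a single plain `weight` parameter.
    remove_weight_norm(conv)

    # Output after the reparameterization has been removed.
    y_after = conv(x)

max_diff = (y_before - y_after).abs().max().item()
print("max abs difference:", max_diff)
```

If the two outputs agree up to floating-point tolerance, removing weight norm changes only how the weight is stored, not what the layer computes.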