maum-ai / voicefilter

Unofficial PyTorch implementation of Google AI's VoiceFilter system
http://swpark.me/voicefilter

[WIP] implement power-law compression loss #18

Open stegben opened 5 years ago

stegben commented 5 years ago

Fixes #14

Still in progress, not yet tested. I have to wait until my initial run finishes...

xiaozhuo12138 commented 4 years ago

Will modifying the loss function improve performance?

stegben commented 4 years ago

Sorry, my 1080 Ti is too weak. Could anybody with better hardware help run the experiment?

xiaozhuo12138 commented 4 years ago

My own server is too slow as well. But I found that the loss function differs from what the paper describes: the current code just uses the amplitude spectrum.
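For reference, a plain MSE on magnitude spectrograms, which is what this comment suggests the code currently computes, is a one-liner. The function and argument names below are illustrative, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

def magnitude_mse_loss(est_mag: torch.Tensor, target_mag: torch.Tensor) -> torch.Tensor:
    """Plain MSE between magnitude spectrograms, with no power-law compression."""
    return F.mse_loss(est_mag, target_mag)
```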

linzwatt commented 4 years ago

I have access to NVIDIA V100 GPUs, and I am currently training a modified version of this model. I can test this PR if you like.

stegben commented 4 years ago

@linzwatt That would be awesome! Thanks a lot

kwikwag commented 4 years ago

hey @linzwatt any results?

kwikwag commented 4 years ago

I am not sure if I ran it right -- I used the config as provided and I get this:

2020-02-29 13:51:24,952 - INFO - Starting new training run
../torch/csrc/utils/python_arg_parser.cpp:698: UserWarning: This overload of add_ is deprecated:
        add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
        add_(Tensor other, Number alpha)
2020-02-29 13:51:34,106 - INFO - Wrote summary at step 1
2020-02-29 13:51:37,200 - ERROR - Loss exploded to nan at step 2!
2020-02-29 13:51:37,459 - INFO - Exiting due to exception: Loss exploded
Traceback (most recent call last):
  File "/mnt/disks/voicefilter-1/data/voicefilter/utils/train.py", line 91, in train
    raise Exception("Loss exploded")
Exception: Loss exploded
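One plausible cause of this NaN (a sketch, not verified against this branch): torch.pow with a fractional exponent has an unbounded derivative at zero, so any zero-valued spectrogram bin poisons the backward pass:

```python
import torch

# d/du u**0.3 = 0.3 * u**(-0.7) diverges as u -> 0, so backpropagating
# through a power-law compression of a zero-valued bin yields a
# non-finite gradient.
x = torch.zeros(1, requires_grad=True)
y = torch.pow(torch.abs(x), 0.3).sum()
y.backward()
print(x.grad)  # non-finite gradient at zero
```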

Edresson commented 4 years ago

@kwikwag I think this implementation does not follow the formula in the paper: in the second term, I believe torch.clamp(x, min=0.0) is unnecessary. Additionally, the order of torch.pow and torch.abs is incorrect; following the formula, you must calculate torch.pow first.

Another point: to avoid gradient explosion, an epsilon must be added to output and target_mag.
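Putting the epsilon suggestion together with power-law compression, here is a minimal sketch of the magnitude term of such a loss. The function and parameter names are mine, and the paper's additional complex-spectrum term is omitted:

```python
import torch
import torch.nn.functional as F

def power_law_compressed_loss(output, target_mag, power=0.3, eps=1e-8):
    """MSE between power-law compressed magnitude spectrograms.

    The small eps keeps torch.pow away from zero, where its gradient
    diverges and would otherwise drive the loss to NaN.
    """
    est_c = torch.pow(output + eps, power)
    tgt_c = torch.pow(target_mag + eps, power)
    return F.mse_loss(est_c, tgt_c)
```

With eps > 0 the gradient at a zero-valued bin stays finite (on the order of eps**(power - 1)), which addresses the kind of explosion seen in the training log above.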