Qinwen-Hu / dparn

Runtime error when running inference with the pre-trained model #5

Closed yugeshav closed 2 years ago

yugeshav commented 2 years ago

Hello @Qinwen-Hu

Amazing work! I am getting this error when I try to evaluate the pre-trained model; can you please help?

RuntimeError: Error(s) in loading state_dict for DPModel: Unexpected key(s) in state_dict: "process_model.intra_mha_list.0.MHA.out_proj.bias", "process_model.intra_mha_list.1.MHA.out_proj.bias".

Regards, Yugesh

Qinwen-Hu commented 2 years ago

Thanks for your interest in our work! I tried running my inference code and didn't encounter this problem, so it is probably due to the PyTorch version: the MHA module is the off-the-shelf module provided in torch.nn, and it varies a little between PyTorch versions. The code in this repository should run well with PyTorch 1.7.1. I hope this solves your problem.
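
For anyone hitting a similar mismatch, one quick way to see exactly which keys differ between the checkpoint and the model built by your local PyTorch version is to compare the two key sets directly. This is only a minimal diagnostic sketch: the checkpoint path and the `DPModel()` construction below are placeholders for whatever Infer.py actually does.

    import torch

    # Placeholder path -- point this at the provided pre-trained checkpoint.
    checkpoint = torch.load("pretrained_dparn.pth", map_location="cpu")
    ckpt_keys = set(checkpoint["state_dict"].keys())

    # Build the model exactly as Infer.py does (constructor args omitted here).
    model = DPModel()
    model_keys = set(model.state_dict().keys())

    print("In checkpoint but not in model:", sorted(ckpt_keys - model_keys))
    print("In model but not in checkpoint:", sorted(model_keys - ckpt_keys))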

koerthawkins commented 2 years ago

@yugeshav I was having the same issue with PyTorch 1.10.1. Here is an easy fix; you just need to insert it into Infer.py:

    # read the state dict into a variable
    state_dict: dict = checkpoint_DPARN["state_dict"]

    # remove the two keys that the model built with newer PyTorch versions
    # does not expect (these trigger the "Unexpected key(s)" error)
    state_dict.pop("process_model.intra_mha_list.0.MHA.out_proj.bias")
    state_dict.pop("process_model.intra_mha_list.1.MHA.out_proj.bias")

    # load the cleaned-up weights into the model
    model.load_state_dict(state_dict)

Also, make sure to import and use signal_processing.py:iSTFT_module_1_8 instead of signal_processing.py:iSTFT_module_1_7 in Infer.py.
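
As an alternative to popping the keys by hand, `load_state_dict(..., strict=False)` skips mismatched keys and reports them, which is a bit more robust if other keys change in future PyTorch versions. This is only a sketch of an alternative, not part of the fix above, and the iSTFT import line is written as I assume it roughly appears in Infer.py:

    # Alternative: ignore mismatched keys instead of popping them one by one,
    # and print what was skipped so nothing disappears silently.
    incompatible = model.load_state_dict(checkpoint_DPARN["state_dict"], strict=False)
    print("Missing keys:", incompatible.missing_keys)
    print("Unexpected keys:", incompatible.unexpected_keys)

    # For PyTorch >= 1.8, switch to the matching iSTFT implementation
    # (import shown as I assume it appears in Infer.py):
    from signal_processing import iSTFT_module_1_8 as iSTFT_module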

yugeshav commented 2 years ago

@koerthawkins Thanks for the suggestion! It worked

koerthawkins commented 2 years ago

Glad to have helped!

ndisci commented 2 years ago

I did what you said and it worked. I am using the pre-trained model with PyTorch 1.10.1, but the result is not good: I listened to the enhanced audio files and also tested on the Valentini test data, where the PESQ is about 2.293.

koerthawkins commented 2 years ago

Yes, unfortunately the pre-trained checkpoint isn't good. I do not understand how it could reach the performance indicated in the paper.

What's more, the author hasn't responded to my pull request for one and a half months now, so I'm not very optimistic that we'll get a better pre-trained checkpoint.

ndisci commented 2 years ago

I got it, thanks for your answer. Did you try to train a model with the training data?

koerthawkins commented 2 years ago

No, I implemented my own training loop and dataloader. That worked moderately well: the model learned and trained, but the results weren't good in comparison to other denoising models, e.g. SkipConvNet. And the VRAM usage is immense due to the MHA layers.

Qinwen-Hu commented 2 years ago

Hi, I'm so sorry that I didn't reply to your pull request in time. I was not very familiar with GitHub (I didn't do much more than read other people's code here) and I didn't notice it. Thanks a lot for fixing the bug. I tried the provided pre-trained model on the Valentini dataset and got 2.98 on the PESQ score, so I'm not sure what caused the performance degradation. There is indeed a performance gap between SCM-DPARN and SOTA 48k SE models (e.g. MTFAA-Net), and VRAM usage is indeed a big problem for dual-path models with MHA layers. We are trying to integrate our techniques with more reasonable structures for the full-band SE task.

koerthawkins commented 2 years ago

Hi Qinwen,

no worries! You're welcome.

When you write of more reasonable structures, do you mean something to replace the MHA layers?

Qinwen-Hu commented 2 years ago

Not exactly; we will try to combine the spectral mapping part with some different models (e.g. complex networks). This work uses MHA layers to replace the RNNs in the dual-path model mainly because we wanted to start from our previous work (DPCRN).

ndisci commented 2 years ago

You should try PyTorch 1.7.1+cu11.0. I got a PESQ score of 3.00 on the Valentini test set.
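
For anyone reproducing these PESQ numbers: wideband PESQ is defined at 16 kHz, so the 48 kHz Valentini files need to be resampled before scoring. Below is a minimal sketch using the `pesq` and `librosa` packages; the directory names and the pairing of clean/enhanced files are assumptions, not part of this repository.

    import glob
    import os

    import librosa
    from pesq import pesq  # pip install pesq

    # Hypothetical directory layout: matching file names in both folders.
    CLEAN_DIR = "valentini/clean_testset_wav"
    ENHANCED_DIR = "valentini/enhanced_testset_wav"

    scores = []
    for clean_path in sorted(glob.glob(os.path.join(CLEAN_DIR, "*.wav"))):
        enhanced_path = os.path.join(ENHANCED_DIR, os.path.basename(clean_path))
        # Wideband PESQ expects 16 kHz input, so resample the 48 kHz files.
        clean, _ = librosa.load(clean_path, sr=16000)
        enhanced, _ = librosa.load(enhanced_path, sr=16000)
        scores.append(pesq(16000, clean, enhanced, "wb"))

    print("Mean PESQ over", len(scores), "files:", sum(scores) / len(scores))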