asteroid-team / asteroid

The PyTorch-based audio source separation toolkit for researchers
https://asteroid-team.github.io/
MIT License

Loading DCU pretrained model (from JorisCos / DCUNet_Libri1Mix_enhsingle_16k ) #450

Closed manu0586 closed 3 years ago

manu0586 commented 3 years ago

asteroid version: 0.4.4

Function call:

    model = DCUNet.from_pretrained("JorisCos/DCUNet_Libri1Mix_enhsingle_16k")

Error:

RuntimeError: Error(s) in loading state_dict for DCUNet:
    size mismatch for masker.decoders.7.deconv.re_module.weight: copying a param with shape torch.Size([180, 45, 7, 5]) from checkpoint, the shape in current model is torch.Size([180, 90, 7, 5]).
    size mismatch for masker.decoders.7.deconv.im_module.weight: copying a param with shape torch.Size([180, 45, 7, 5]) from checkpoint, the shape in current model is torch.Size([180, 90, 7, 5]).
    size mismatch for masker.decoders.7.norm.re_module.weight: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.re_module.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.re_module.running_mean: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.re_module.running_var: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.im_module.weight: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.im_module.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.im_module.running_mean: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.7.norm.im_module.running_var: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.deconv.re_module.weight: copying a param with shape torch.Size([90, 45, 1, 7]) from checkpoint, the shape in current model is torch.Size([135, 90, 1, 7]).
    size mismatch for masker.decoders.8.deconv.im_module.weight: copying a param with shape torch.Size([90, 45, 1, 7]) from checkpoint, the shape in current model is torch.Size([135, 90, 1, 7]).
    size mismatch for masker.decoders.8.norm.re_module.weight: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.re_module.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.re_module.running_mean: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.re_module.running_var: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.im_module.weight: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.im_module.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.im_module.running_mean: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.decoders.8.norm.im_module.running_var: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([90]).
    size mismatch for masker.output_layer.0.re_module.weight: copying a param with shape torch.Size([90, 1, 7, 1]) from checkpoint, the shape in current model is torch.Size([135, 1, 7, 1]).
    size mismatch for masker.output_layer.0.im_module.weight: copying a param with shape torch.Size([90, 1, 7, 1]) from checkpoint, the shape in current model is torch.Size([135, 1, 7, 1]).

Information from the cached model (retrieved via the `cached_download` function + `torch.load`):

    model_args = {'architecture': 'Large-DCUNet-20', 'stft_kernel_size': 1024, 'stft_stride': 256, 'sample_rate': 16000.0, 'fix_length_mode': 'pad', 'n_src': 1}
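To pinpoint exactly which parameters disagree before calling `load_state_dict`, one can compare the checkpoint's tensor shapes against those of a freshly constructed model. A minimal sketch (shapes shown as plain tuples for illustration; with a real checkpoint you would build each dict as `{k: tuple(v.shape) for k, v in state_dict.items()}`):

```python
def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Return (key, checkpoint_shape, model_shape) for every parameter
    present in both dicts whose shapes disagree."""
    return [
        (key, ckpt_shapes[key], model_shapes[key])
        for key in ckpt_shapes
        if key in model_shapes and ckpt_shapes[key] != model_shapes[key]
    ]

# Shapes taken from the error above: the checkpoint has 45 channels
# where the current model expects 90.
ckpt = {
    "masker.decoders.7.deconv.re_module.weight": (180, 45, 7, 5),
    "masker.decoders.7.norm.re_module.weight": (45,),
}
model = {
    "masker.decoders.7.deconv.re_module.weight": (180, 90, 7, 5),
    "masker.decoders.7.norm.re_module.weight": (90,),
}

for key, ckpt_shape, model_shape in find_shape_mismatches(ckpt, model):
    print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
```

This reports the same layers listed in the traceback, which narrows the problem down to the last decoders and the output layer.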

JorisCos commented 3 years ago

Hello thanks for reporting, I took down the model for now and will upload a corrected version.

realkris commented 3 years ago

Hi, has this issue been solved? C:

JorisCos commented 3 years ago

Not yet, I'm training the corrected model and will release it when it's ready. I will post the link here and close the issue when it's done.

JorisCos commented 3 years ago

Training a model like this on LibriMix takes 5-7 days. Sorry for the inconvenience.

mpariente commented 3 years ago

Can we explain what was the problem for loading the previous one?

JorisCos commented 3 years ago

The first Large-DCUNet-20 architecture wasn't implemented correctly, and we have since corrected it. I trained the released model on the earlier, incorrect architecture, before the current version was fixed. The correction resulted in small size changes in the last two decoder layers compared to the architecture I used, which is why the checkpoint can't be loaded: its shapes no longer match the current model.
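Until a retrained checkpoint is published, one stopgap (not a fix: the dropped layers keep their random initialization, so enhancement quality will suffer) is to discard the mismatched entries and load the rest with `strict=False`. Sketched here on plain dicts of shapes; with real tensors the filtered dict would be passed to `model.load_state_dict(filtered, strict=False)`. The decoder-6 shape below is hypothetical, used only to show a matching entry surviving the filter:

```python
def filter_matching(ckpt_shapes, model_shapes):
    """Keep only checkpoint entries whose shape matches the model's."""
    return {
        key: shape
        for key, shape in ckpt_shapes.items()
        if model_shapes.get(key) == shape
    }

ckpt = {
    "masker.decoders.6.deconv.re_module.weight": (360, 90, 7, 5),  # hypothetical, matches
    "masker.decoders.7.deconv.re_module.weight": (180, 45, 7, 5),  # mismatched (from the error)
}
model = {
    "masker.decoders.6.deconv.re_module.weight": (360, 90, 7, 5),
    "masker.decoders.7.deconv.re_module.weight": (180, 90, 7, 5),
}

loadable = filter_matching(ckpt, model)
print(sorted(loadable))  # only the matching decoder-6 weight survives
```

Note that `strict=False` only tolerates missing or unexpected keys; it does not bypass size mismatches, which is why the mismatched entries must be removed before loading.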

JorisCos commented 3 years ago

The model is fixed and now available here. Thanks again for reporting the issue.