caillonantoine / msprior


Error on export recurrent model (torch.Size shape mismatch) #5

Open devstermarts opened 1 year ago

devstermarts commented 1 year ago

Hey @caillonantoine, I'm running into the following error when running msprior export:

streaming mode is set to True
Traceback (most recent call last):
  File "/content/miniconda/bin/msprior", line 8, in <module>
    sys.exit(main())
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/main_cli.py", line 28, in main
    app.run(module.main)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/export.py", line 16, in main
    model = ScriptedPrior(
  File "/content/miniconda/lib/python3.9/site-packages/msprior/scripted.py", line 53, in __init__
    model.load_state_dict(ckpt, strict=False)
  File "/content/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Prior:
    size mismatch for decoder.net.0._state: copying a param with shape torch.Size([8, 64, 512]) from checkpoint, the shape in current model is torch.Size([8, 1, 512]).
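Note that scripted.py already calls load_state_dict with strict=False. In PyTorch that flag only tolerates keys that are missing from, or unexpected in, the checkpoint; a key present on both sides with a different shape is still collected as an error and raised. The stdlib-only sketch below mimics that check on plain dicts of shapes to show why decoder.net.0._state fails regardless. The function name find_size_mismatches and the dict-of-tuples representation are illustrative assumptions, not part of msprior or PyTorch.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_size_mismatches(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> List[str]:
    """Mimic the shape check that load_state_dict performs even with strict=False.

    strict=False only tolerates keys that are missing from, or unexpected in,
    the checkpoint. A key present on both sides with a different shape is
    still reported as an error.
    """
    errors: List[str] = []
    for name, ckpt_shape in ckpt.items():
        if name not in model:
            continue  # unexpected key: tolerated when strict=False
        if ckpt_shape != model[name]:
            errors.append(
                f"size mismatch for {name}: copying a param with shape "
                f"{list(ckpt_shape)} from checkpoint, "
                f"the shape in current model is {list(model[name])}."
            )
    return errors


if __name__ == "__main__":
    # Shapes of the recurrent state are taken from the traceback above;
    # "extra.buffer" is an illustrative unexpected key, not a real msprior name.
    ckpt = {"decoder.net.0._state": (8, 64, 512), "extra.buffer": (4,)}
    model = {"decoder.net.0._state": (8, 1, 512)}
    for msg in find_size_mismatches(ckpt, model):
        print(msg)
```

The unexpected key is silently skipped, while the shape difference on the _state entry is reported, which is the same split between tolerated and fatal problems seen in the traceback.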

msprior version is 1.1.2. Training was done with --config recurrent.

snnithya commented 1 year ago

Adding some speculation here until someone else replies: I think the problem is caused by the different batch_size values used by the training and export scripts. Setting --batch_size 64 on the export script (64 is the default batch size during training) seemed to get rid of the error. However, I'm not sure this is the right way to export the model, and I'm also unsure why this error didn't show up with the decoder_only configuration.
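The shapes in the error fit this explanation: the checkpoint's decoder.net.0._state is [8, 64, 512] and the export-time model's is [8, 1, 512], so only the middle dimension differs, and it matches the default training batch size of 64. That suggests _state is a cached recurrent state rather than a learned weight. An alternative to exporting with --batch_size 64, sketched here under that assumption and not tested against msprior, would be to drop the _state entries from the checkpoint before loading, so the export model keeps its own batch-1 buffer. The helper name drop_runtime_state and the "._state" suffix filter are hypothetical.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def drop_runtime_state(state_dict: Dict[str, V], suffix: str = "._state") -> Dict[str, V]:
    """Return a copy of state_dict without the entries whose key ends with suffix.

    Hypothetical workaround: assumes every key ending in "._state" is a cached
    recurrent hidden state (sized by the training batch) rather than a learned
    weight, so the export-time model can keep its own batch-1 buffer instead.
    """
    return {name: value for name, value in state_dict.items() if not name.endswith(suffix)}


if __name__ == "__main__":
    # Values are shapes here for illustration; with PyTorch they would be tensors.
    # "decoder.net.0.weight" is an illustrative key, not taken from the traceback.
    ckpt = {"decoder.net.0._state": (8, 64, 512), "decoder.net.0.weight": (512, 512)}
    print(sorted(drop_runtime_state(ckpt)))  # ['decoder.net.0.weight']
```

Used in place of the call at scripted.py line 53, this would look like model.load_state_dict(drop_runtime_state(ckpt), strict=False); because strict=False is already set, the dropped keys would only be reported as missing instead of raising. Whether the freshly initialised state behaves correctly during streaming inference has not been verified.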

devstermarts commented 1 year ago

Thanks @snnithya for looking into this. Using --batch_size 64 did the trick on msprior export for now.