Open devstermarts opened 1 year ago
Adding some speculation here until someone else replies: I think the problem comes from the different `batch_size` used by the training and export scripts. Passing `--batch_size 64` to the export script (since the batch size defaults to 64 during training) seemed to get rid of the error, although I'm not sure this is the right way to export the model. I'm also unsure why the error doesn't show up with the `decoder_only` configuration.
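The mismatch in the traceback can be reproduced in isolation. The sketch below (hypothetical names; `TinyRecurrent` and its `_state` buffer only mimic the shape reported for `decoder.net.0._state`, they are not msprior's actual classes) shows a recurrent state buffer whose batch dimension is baked in at construction time. Note that PyTorch raises on size mismatches even with `strict=False`, which is why the `strict=False` in `scripted.py` doesn't help here:

```python
import torch
import torch.nn as nn

class TinyRecurrent(nn.Module):
    """Toy module with a cached recurrent state shaped (layers, batch, hidden)."""

    def __init__(self, batch_size: int):
        super().__init__()
        # Hypothetical stand-in for decoder.net.0._state in the traceback.
        self.register_buffer("_state", torch.zeros(8, batch_size, 512))

# Checkpoint produced by "training" with the default batch size of 64.
ckpt = TinyRecurrent(batch_size=64).state_dict()

# Export-time model built with batch_size=1: the buffer shapes disagree,
# so load_state_dict raises RuntimeError (size mismatch for _state),
# even when strict=False, since strict only relaxes key matching.
try:
    TinyRecurrent(batch_size=1).load_state_dict(ckpt, strict=False)
except RuntimeError as exc:
    print("size mismatch:", "_state" in str(exc))

# Workaround from this thread: rebuild the model with the matching batch
# size (equivalent to passing --batch_size 64 at export), then load cleanly.
model = TinyRecurrent(batch_size=64)
model.load_state_dict(ckpt)
```

This is consistent with the observed fix: matching `--batch_size` at export makes the cached-state shapes line up, though it doesn't explain why the state was saved with the batch dimension in the first place.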
Thanks @snnithya for looking into this. Using `--batch_size 64` did the trick on `msprior export` for now.
Hey @caillonantoine, I'm running into the following error on `msprior export`:
```
streaming mode is set to True
Traceback (most recent call last):
  File "/content/miniconda/bin/msprior", line 8, in <module>
    sys.exit(main())
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/main_cli.py", line 28, in main
    app.run(module.main)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/export.py", line 16, in main
    model = ScriptedPrior(
  File "/content/miniconda/lib/python3.9/site-packages/msprior/scripted.py", line 53, in __init__
    model.load_state_dict(ckpt, strict=False)
  File "/content/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Prior:
	size mismatch for decoder.net.0._state: copying a param with shape torch.Size([8, 64, 512]) from checkpoint, the shape in current model is torch.Size([8, 1, 512]).
```
msprior version is 1.1.2. Training was done with `--config recurrent`.