orlando-labs closed this pull request 3 years ago
Hey @orlando-labs, thanks for another PR! Overall, looks good. For load_state_dict, I think it'd be better to move to the approach PyTorch uses of recursing child modules instead of recursing state dict keys (if that makes sense).
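The recursive approach mentioned above can be sketched in plain Ruby. This is a hypothetical, simplified model (`MiniModule` is not part of the library): each module copies the entries under its own prefix and then recurses into its children with an extended prefix, rather than walking the flat key list of the state dict, which also makes tracking missing and unexpected keys straightforward.

```ruby
# Hypothetical sketch of PyTorch-style load_state_dict: recurse over
# child modules instead of iterating the state dict's flat keys.
class MiniModule
  attr_reader :params, :children

  def initialize(params: {}, children: {})
    @params = params      # name => value (stands in for tensors)
    @children = children  # name => MiniModule
  end

  # Each module loads its own entries under `prefix`, then recurses
  # into children with the prefix extended by "<child_name>.".
  def load_state_dict(state_dict, prefix = "", missing: [], unexpected: nil)
    unexpected ||= state_dict.keys.dup

    @params.each_key do |name|
      key = prefix + name
      if state_dict.key?(key)
        @params[name] = state_dict[key]
        unexpected.delete(key) # this key was consumed by some module
      else
        missing << key # module expected a value that isn't in the dict
      end
    end

    @children.each do |name, child|
      child.load_state_dict(state_dict, "#{prefix}#{name}.",
                            missing: missing, unexpected: unexpected)
    end

    [missing, unexpected]
  end
end
```

Usage, mirroring the missing/unexpected-key tests mentioned below:

```ruby
child = MiniModule.new(params: { "weight" => 1 })
root = MiniModule.new(params: { "bias" => 0 }, children: { "fc" => child })
missing, unexpected = root.load_state_dict(
  { "bias" => 5, "fc.weight" => 7, "fc.extra" => 9 }
)
# missing is empty; unexpected contains "fc.extra"
```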
Also, it looks like some trailing whitespace is making it into PRs.
FWIW, I added some more tests at the end of test/nn/module_test.rb for missing and unexpected keys. I'm happy to help with the change mentioned above, so just let me know if that's wanted.
Hi @ankane, sorry for such a long absence. The approach you mentioned is right overall, as it matches PyTorch's behavior, so it's up to you. But I don't think it's necessary, at least for now.
Hey @orlando-labs, can you give master a shot?
Sure! Done and merged.
Hey @orlando-labs, that test case should be on master (part of c39c487aa46ebd8b20cdc64303a1b61ee2469c87). If the code works for your models, I think we can close this out.
@ankane, in practice I see no issues with any of my models in my train/inference pipelines.
Awesome, thanks for all the work on this!
PyTorch saves buffers (running means and variances) along with parameters in state dicts. This PR addresses that and allows models with batch normalization to be persisted and restored correctly.
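To illustrate why this matters, here is a hypothetical, simplified sketch (the `MiniBatchNorm` class is invented for the example, not the library's API): a batch-norm layer's running statistics are buffers, updated during training but not by the optimizer, so if state_dict only captured parameters, a reloaded model would normalize with freshly-reset statistics.

```ruby
# Hypothetical sketch: the state dict must include buffers
# (running_mean / running_var) as well as learnable parameters.
class MiniBatchNorm
  attr_accessor :weight, :bias, :running_mean, :running_var

  def initialize
    @weight = 1.0        # learnable parameter
    @bias = 0.0          # learnable parameter
    @running_mean = 0.0  # buffer: tracked during training, not optimized
    @running_var = 1.0   # buffer
  end

  def state_dict
    # Persist parameters AND buffers, matching PyTorch's behavior.
    {
      "weight" => @weight, "bias" => @bias,
      "running_mean" => @running_mean, "running_var" => @running_var
    }
  end

  def load_state_dict(sd)
    @weight = sd["weight"]
    @bias = sd["bias"]
    @running_mean = sd["running_mean"]
    @running_var = sd["running_var"]
  end
end
```

With this, statistics accumulated during training survive a save/load round trip into a fresh module, so inference-mode normalization behaves the same as before saving.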