ankane / torch.rb

Deep learning for Ruby, powered by LibTorch

Loading/saving modules with buffers #18

Closed orlando-labs closed 3 years ago

orlando-labs commented 3 years ago

PyTorch saves buffers (running means and variances) along with parameters in state dicts. This PR addresses that and allows models with batch normalization to be persisted and restored correctly.
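To illustrate the point of the PR, here is a minimal plain-Ruby sketch (not torch.rb's actual implementation; all class and method names are illustrative) of why a state dict must merge buffers with parameters: a BatchNorm-style layer carries running statistics that the optimizer never touches, so saving parameters alone loses them.

```ruby
# Toy sketch: a BatchNorm-like layer whose state includes learnable
# parameters AND non-learnable buffers (running statistics).
class ToyBatchNorm
  attr_reader :parameters, :buffers

  def initialize
    # learnable parameters, updated by the optimizer
    @parameters = { "weight" => [1.0], "bias" => [0.0] }
    # buffers: tracked statistics, updated during training but not by the optimizer
    @buffers = { "running_mean" => [0.0], "running_var" => [1.0] }
  end

  # A complete state dict merges parameters and buffers; without the
  # merge, running stats are silently dropped on save/load.
  def state_dict
    @parameters.merge(@buffers)
  end

  def load_state_dict(state)
    state.each do |key, value|
      if @parameters.key?(key)
        @parameters[key] = value
      elsif @buffers.key?(key)
        @buffers[key] = value
      end
    end
  end
end
```

With this sketch, `ToyBatchNorm.new.state_dict.keys` yields all four entries (`weight`, `bias`, `running_mean`, `running_var`), mirroring what PyTorch produces for a real batch norm layer.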

ankane commented 3 years ago

Hey @orlando-labs, thanks for another PR! Overall, looks good. For load_state_dict, I think it'd be better to move to the approach PyTorch uses of recursing child modules instead of recursing state dict keys (if that makes sense).
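A rough sketch of the module-recursion approach being suggested (plain Ruby with illustrative names, not torch.rb's internals): rather than iterating the flat key list, each module loads its own entries and then recurses into its named children, extending the key prefix at each level, the way PyTorch's `_load_from_state_dict` traversal works.

```ruby
# Toy module tree: each module owns a flat hash of params/buffers and a
# hash of named child modules.
class ToyModule
  attr_reader :own_state, :children

  def initialize(own_state = {}, children = {})
    @own_state = own_state  # this module's params + buffers
    @children = children    # name => ToyModule
  end

  # Recurse the module tree instead of the state dict keys: load this
  # module's entries under the current prefix, then descend into each
  # child with "<prefix><child name>." prepended.
  def load_state_dict(state_dict, prefix = "")
    own_state.each_key do |name|
      key = prefix + name
      own_state[name] = state_dict[key] if state_dict.key?(key)
    end
    children.each do |name, child|
      child.load_state_dict(state_dict, "#{prefix}#{name}.")
    end
  end
end
```

For example, loading `{"fc.weight" => [1, 2]}` into a root module with child `"fc"` routes the value to the child's `weight` entry via the `"fc."` prefix.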

Also, it looks like some trailing whitespace is making it into PRs.

ankane commented 3 years ago

fwiw, I added some more tests at the end of test/nn/module_test.rb for missing and unexpected keys. I'm happy to help with the change mentioned above, so just let me know if that's wanted.
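The missing/unexpected-key checks those tests exercise amount to a set difference between the keys the model expects and the keys present in the loaded state dict. A minimal plain-Ruby illustration (hypothetical helper, not torch.rb's code):

```ruby
# Compare the keys a model's state_dict defines against the keys found
# in a state dict being loaded. Keys the model expects but the dict
# lacks are "missing"; keys the dict has but the model lacks are
# "unexpected".
def key_mismatches(expected_keys, state_dict_keys)
  {
    missing: expected_keys - state_dict_keys,
    unexpected: state_dict_keys - expected_keys
  }
end
```

For example, `key_mismatches(["weight", "bias"], ["weight", "extra"])` returns `{missing: ["bias"], unexpected: ["extra"]}`, which is the kind of result a `load_state_dict` implementation can surface as an error message.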

orlando-labs commented 3 years ago

Hi @ankane, sorry for the long absence. The approach you mentioned is right overall, since it matches PyTorch's behavior, so it's up to you. But I don't think it's necessary, at least for now.

ankane commented 3 years ago

Hey @orlando-labs, can you give master a shot?

orlando-labs commented 3 years ago

Sure! Done and merged.

ankane commented 3 years ago

Hey @orlando-labs, that test case should be on master (part of c39c487aa46ebd8b20cdc64303a1b61ee2469c87). If the code works for your models, I think we can close this out.

orlando-labs commented 3 years ago

@ankane, in practice I see no issues with any of my models in my training/inference pipelines.

ankane commented 3 years ago

Awesome, thanks for all the work on this!