cambridge-mlg / EinsumNetworks


EiNets cannot be saved to/loaded from disk when `use_em=False` #1

Open arranger1044 opened 4 years ago

arranger1044 commented 4 years ago

Setting `use_em=False` (e.g., to train with SGD) causes the reparam function to be created as a local function, which cannot be readily pickled.

This means that `torch.save` and `torch.load`, which use pickle, throw exceptions like:

AttributeError: Can't pickle local object 'SumLayer.reparam_function.<locals>.reparam'
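The same failure can be reproduced without PyTorch. The snippet below is a minimal sketch using only the standard library; `reparam_function` and `pickling_error` are hypothetical names chosen to echo the traceback, and the function body is a placeholder, not the EinsumNetworks implementation.

```python
import pickle


def reparam_function():
    """Return a closure, mirroring the pattern named in the traceback."""

    def reparam(params):
        # Placeholder transformation (an assumption, not the library's code).
        return [max(p, 0.0) for p in params]

    return reparam


def pickling_error(obj):
    """Return the exception raised by pickle.dumps(obj), or None on success."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        return exc
    return None


local_fn = reparam_function()
error = pickling_error(local_fn)

# pickle stores functions by qualified name, and a name containing
# "<locals>" cannot be looked up again at load time, so dumping fails.
print(type(error).__name__, error)
```

The exact exception type and message depend on the Python version, which is why the helper catches both `AttributeError` and `pickle.PicklingError`.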

Using load_from_state seems not to be supported out of the box. To reproduce this behaviour, see this minimal working example https://github.com/arranger1044/EinsumNetworks-1/blob/master/test/test_load_save.py

lioutikov commented 3 years ago

Hi, just to throw my two cents in: I don't think this is an issue that needs fixing in this project. Without looking too deeply into the code, I believe returning a function rather than just the reparameterization itself could be a design choice, e.g., for modularity. If so, I believe it's unwise to force changes to this project because of the "shortcomings" of another package, i.e., pickle's issue with lambda functions, if that is avoidable.

The issue that pickle doesn't deal well with anonymous/lambda/local functions is known, and there are other ways to deal with it. A common way is to use dill instead of pickle. PyTorch offers the `pickle_module` argument for that exact reason. Hence, the script you linked is easily fixed using the following adaptations:

quick_fix.txt

arranger1044 commented 3 years ago

Thanks @lioutikov, good point about using dill! Perhaps it would be nice to bake this into a wrapper for loading/saving EiNets, so that users don't have to remember to set `pickle_module` manually?

I am sceptical about the modularity argument, however, as there is no apparent advantage (in modularity or otherwise) over using polymorphism through inheritance. Am I missing something here?
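To make the inheritance alternative concrete, here is a minimal sketch in which each reparametrization is a module-level callable class instead of a closure. The class names and the softmax body are assumptions made for illustration; the thread does not state which transformation EinsumNetworks applies.

```python
import math
import pickle


class Reparam:
    """Hypothetical base class: a reparametrization is a callable object."""

    def __call__(self, params):
        raise NotImplementedError


class IdentityReparam(Reparam):
    """Leaves the parameters unchanged."""

    def __call__(self, params):
        return list(params)


class SoftmaxReparam(Reparam):
    """Maps unconstrained parameters to normalized positive weights."""

    def __call__(self, params):
        # Subtract the maximum for numerical stability before exponentiating.
        shift = max(params)
        exps = [math.exp(p - shift) for p in params]
        total = sum(exps)
        return [e / total for e in exps]


# Instances of module-level classes are pickled by class reference plus
# state, so the round trip works with plain pickle.
reparam = SoftmaxReparam()
restored = pickle.loads(pickle.dumps(reparam))
weights = restored([0.0, 0.0])
print(weights)
```

A layer could then hold one of these objects in place of the local function, which would keep the choice of reparametrization swappable while remaining picklable without dill.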