LukasZahradnik / PyNeuraLogic

PyNeuraLogic lets you use Python to create Differentiable Logic Programs
https://pyneuralogic.readthedocs.io/
MIT License
281 stars, 18 forks

[🐛 Bug Report]: saving `state_dict` with torch or pickle causes an error #57

Open: DillonZChen opened this issue 6 months ago

DillonZChen commented 6 months ago

Describe the bug

As the title says, saving a `state_dict` with torch causes an error:

_pickle.PicklingError: Can't pickle <java class 'java.lang.String'>: attribute lookup java.lang.String on jpype._jstring failed

Steps to reproduce the behavior

Assuming torch is imported and we have a (trained) model initialised by `model = template.build(settings)`, the following code raises the error:

state_dict = model.state_dict()
torch.save(state_dict, save_file)
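Because the traceback comes from a JVM object, the failure mode can be illustrated without a JVM. The sketch below is a minimal, hypothetical reproduction: `FakeJavaString` is an illustrative stand-in (not part of PyNeuraLogic or JPype) that, like the wrapped `java.lang.String` in the traceback, converts cleanly with `str()` but refuses to be pickled. The state-dict layout is also assumed, modelled only on the `weight_names` key mentioned in this report.

```python
import pickle


class FakeJavaString:
    """Hypothetical stand-in for a JPype-wrapped java.lang.String.

    It converts cleanly with str() but cannot be pickled. Here the error is
    raised explicitly instead of coming from a real JVM object.
    """

    def __init__(self, value: str) -> None:
        self._value = value

    def __str__(self) -> str:
        return self._value

    def __reduce__(self):
        raise pickle.PicklingError(
            "Can't pickle <java class 'java.lang.String'> (simulated)"
        )


def try_pickle(obj) -> bool:
    """Return True if obj pickles, False if a PicklingError is raised."""
    try:
        pickle.dumps(obj)
    except pickle.PicklingError:
        return False
    return True


# Hypothetical state dict: only the "weight_names" values are Java-like strings.
state_dict = {
    "weights": {0: [0.5, -1.2]},
    "weight_names": {0: FakeJavaString("w_edge")},
}

print(try_pickle(state_dict))  # False: one unpicklable value breaks the whole dict
```

A single non-picklable value anywhere in the nested dict is enough to make the whole save fail, which matches the behaviour in the report.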

Expected behavior

The state dict is saved without a pickling error.

Environment

No response

Additional context

I have an easy workaround, which does not seem to cause any issues: convert the Java strings into Python strings before saving.

state_dict = model.state_dict()
state_dict["weight_names"] = {k: str(v) for k, v in state_dict["weight_names"].items()}
torch.save(state_dict, save_file)

I haven't noticed any unexpected behaviour from doing this; for example, loaded models give the same predictions.
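The workaround can be checked end to end with plain `pickle` (by default, `torch.save` also serialises objects through Python's `pickle` module). This is a sketch under stated assumptions: `FakeJavaString` is a hypothetical stand-in for the JPype string, `make_picklable` is an illustrative helper name, and the dict layout is assumed. Only the `str()` conversion of the `weight_names` values comes from the report.

```python
import io
import pickle


class FakeJavaString:
    """Hypothetical stand-in for a JPype java.lang.String: str()-able, not picklable."""

    def __init__(self, value: str) -> None:
        self._value = value

    def __str__(self) -> str:
        return self._value

    def __reduce__(self):
        raise pickle.PicklingError(
            "Can't pickle <java class 'java.lang.String'> (simulated)"
        )


def make_picklable(state_dict: dict) -> dict:
    """Return a copy of state_dict whose "weight_names" values are Python strs.

    Same conversion as the workaround above; other entries are left as they are.
    """
    fixed = dict(state_dict)  # shallow copy: the caller's dict is not mutated
    fixed["weight_names"] = {k: str(v) for k, v in state_dict["weight_names"].items()}
    return fixed


# Hypothetical state dict: only "weight_names" holds Java-like strings.
state_dict = {
    "weights": {0: [0.5, -1.2]},
    "weight_names": {0: FakeJavaString("w_edge")},
}

# Save and reload through an in-memory buffer instead of a file.
buffer = io.BytesIO()
pickle.dump(make_picklable(state_dict), buffer)
buffer.seek(0)
restored = pickle.load(buffer)

print(restored["weight_names"])  # {0: 'w_edge'}
```

Copying the dict first leaves the live model's `state_dict` untouched, so the conversion only affects what is written to disk.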

LukasZahradnik commented 6 months ago

Hi, thanks for opening an issue. I've prepared a fix (which works basically the same way as your workaround), and it will be released once I have more changes implemented (hopefully soon). I'll keep the issue open until then.

LukasZahradnik commented 6 months ago

Hi, this should be fixed in the latest version (0.7.14). Could you please confirm it works for you now? Thanks

DillonZChen commented 6 months ago

Hi Lukas,

I can confirm it works perfectly now. Thanks