lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Could not call torch.save on the model #141

Closed: frederikfab closed this issue 1 year ago

frederikfab commented 1 year ago

Since fast attention was implemented, calling torch.save on the model raises a pickling error:

----> 1 torch.save(x, 'test.pt')

File /usr/local/Caskroom/miniconda/base/envs/repo/lib/python3.8/site-packages/torch/serialization.py:441, in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    439 if _use_new_zipfile_serialization:
    440     with _open_zipfile_writer(f) as opened_zipfile:
--> 441         _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    442         return
    443 else:

File /usr/local/Caskroom/miniconda/base/envs/repo/lib/python3.8/site-packages/torch/serialization.py:653, in _save(obj, zip_file, pickle_module, pickle_protocol)
    651 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    652 pickler.persistent_id = persistent_id
--> 653 pickler.dump(obj)
    654 data_value = data_buf.getvalue()
    655 zip_file.write_record('data.pkl', data_value, len(data_value))

PicklingError: Can't pickle <class 'x_transformers.attend.EfficientAttentionConfig'>: attribute lookup EfficientAttentionConfig on x_transformers.attend failed
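For context, this kind of PicklingError usually means pickle cannot find the class under its own name in its defining module: pickle serializes instances by reference to the class, and the lookup fails if the class's typename does not match the module-level attribute it is bound to. Below is a minimal sketch in plain Python (not the library's actual code; the field names are assumed) showing how a namedtuple bound under a different name triggers the same error, and how matching the names avoids it.

import pickle
from collections import namedtuple

# Hypothetical setup: the typename ('EfficientAttentionConfig') does not match
# the name it is bound to ('Config'), so pickle's attribute lookup fails at dump time.
Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

try:
    pickle.dumps(Config(True, True, True))
except pickle.PicklingError as err:
    print(err)  # attribute lookup EfficientAttentionConfig on __main__ failed

# Binding the namedtuple under its own typename restores picklability.
EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
pickle.dumps(EfficientAttentionConfig(True, True, True))  # works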

You can reproduce this with the following script:

import torch
from x_transformers import TransformerWrapper, Decoder

x = TransformerWrapper(num_tokens=10, max_seq_len=10, attn_layers=Decoder(dim=10, depth=1, heads=1))
torch.save(x, 'test.pt')  # raises the PicklingError above
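Until the config class is made picklable, a common workaround (standard PyTorch practice, shown here only as a sketch, not a fix for the underlying issue) is to save the state_dict instead of the whole model object, so no custom classes are pickled at all.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(num_tokens=10, max_seq_len=10, attn_layers=Decoder(dim=10, depth=1, heads=1))

# Only tensors and plain containers are serialized here.
torch.save(model.state_dict(), 'test_state_dict.pt')

# To restore, rebuild the model with the same hyperparameters and load the weights.
restored = TransformerWrapper(num_tokens=10, max_seq_len=10, attn_layers=Decoder(dim=10, depth=1, heads=1))
restored.load_state_dict(torch.load('test_state_dict.pt'))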
lucidrains commented 1 year ago

oops, thank you!