nikitakit / self-attentive-parser

High-accuracy NLP parser with models for 11 languages.
https://parser.kitaev.io/
MIT License

Training script crashes on pytorch 1.2 #42

Closed Genius1237 closed 3 years ago

Genius1237 commented 5 years ago

The current version of the code does not run with pytorch 1.2, which is currently the latest release. I am running the training script on the PTB data with --use-words as the only flag.

The error occurs in the call to FeatureDropoutFunction.apply(), at the line https://github.com/nikitakit/self-attentive-parser/blob/1ee43a8f93d6f3259c09ea1ff57cf5124ec32efc/src/parse_nk.py#L107 . output has shape (2016, 1024) while ctx.noise has shape (1379, 1024), so the mul operation fails.

Note that this does not happen on every call to FeatureDropoutFunction.apply(). While stepping through, the exception appears only on the second call. The first time it's called, the dimensions match and no exception is thrown.

With PyTorch 1.1, these errors do not appear. In a trial run, output and ctx.noise both have shape (1413, 1024) and there is no problem.

I can provide further stack traces if needed.
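For anyone trying to reproduce the symptom in isolation: the failure mode described above is an in-place elementwise multiply between tensors whose first dimensions disagree. A minimal sketch (the shapes are taken from the report; the tensor contents are arbitrary placeholders, not the parser's actual activations):

```python
import torch

# Stand-ins for the tensors in FeatureDropoutFunction.apply():
output = torch.randn(2016, 1024)  # shape seen for `output`
noise = torch.randn(1379, 1024)   # shape seen for `ctx.noise`

# output.mul_(noise) raises a RuntimeError because 2016 != 1379
# and the shapes are not broadcastable.
try:
    output.mul_(noise)
    crashed = False
except RuntimeError:
    crashed = True

print("mul failed due to shape mismatch:", crashed)
```

This is only the surface error; the underlying cause (why the noise tensor ends up with a stale batch size under pytorch 1.2) is addressed in the maintainer's reply below.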

Lijiachen1018 commented 4 years ago

I hit this error too with PyTorch 1.2.0, and likewise version 1.1 works.

nikitakit commented 4 years ago

For now the code only supports pytorch 1.1, and the older pytorch-pretrained-bert package (rather than the new version called transformers, which changes the API).

I think the pytorch 1.2 incompatibilities are caused by the introduction of the bool dtype for pytorch. Switching from uint8 masks to bool masks should fix most of the compatibility errors.
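To illustrate the suggested fix: pytorch 1.2 introduced the torch.bool dtype and deprecated uint8 tensors for masking, so masks built the old way need a one-line dtype conversion. A minimal sketch (the length-mask construction here is a generic example, not the parser's actual mask code):

```python
import torch

# A padding mask built the pre-1.2 way, with dtype uint8.
lengths = torch.tensor([3, 5])
max_len = 5
positions = torch.arange(max_len).unsqueeze(0)  # shape (1, max_len)
mask_uint8 = (positions < lengths.unsqueeze(1)).to(torch.uint8)

# The fix: convert uint8 masks to bool before using them for
# indexing or masked_fill on pytorch 1.2+.
mask_bool = mask_uint8.to(torch.bool)

x = torch.ones(2, max_len)
# masked_fill with a bool mask works without deprecation warnings:
padded = x.masked_fill(~mask_bool, 0.0)
```

The same conversion applies wherever a uint8 tensor feeds masked_fill, masked_select, or boolean indexing; comparison operators themselves already return torch.bool on 1.2+, so newly constructed masks usually need no change.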

Lijiachen1018 commented 4 years ago

> For now the code only supports pytorch 1.1, and the older pytorch-pretrained-bert package (rather than the new version called transformers, which changes the API).
>
> I think the pytorch 1.2 incompatibilities are caused by the introduction of the bool dtype for pytorch. Switching from uint8 masks to bool masks should fix most of the compatibility errors.

Thank you!

yangky11 commented 4 years ago

I had the same issue and switching from uint8 to bool seems to work for me.

mengxj08 commented 4 years ago

Switching from torch.uint8 to torch.bool works for me (PyTorch 1.2).

nikitakit commented 3 years ago

Benepar v0.2.0a has been updated to work with pytorch 1.6.