kashif / pytorch-transformer-ts

Repository of Transformer based PyTorch Time Series Models
MIT License

Feedback welcome on xformers #1

Open blefaudeux opened 2 years ago

blefaudeux commented 2 years ago

hey there, while investigating some strange download numbers for xformers I stumbled onto this repo, and had a look at https://github.com/kashif/pytorch-transformer-ts/blob/main/xformers/xformers.ipynb

From what I can see, an attention mask is being passed there. In xformers we've tried to standardize on an additive mask (i.e. positions you want to nuke are set to "-inf", since the mask is applied pre-softmax). One benefit is that other encodings can be passed as additive masks too, for instance if you want to emphasize local attention (à la ALiBi).
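For illustration, a minimal sketch of what such an additive mask looks like (all shapes here are made up, not taken from the notebook): masked positions get -inf, allowed positions get 0, and the mask is added to the attention scores before the softmax.

```python
import torch

# Hypothetical sizes, for illustration only.
batch, heads, seq_len = 2, 4, 5

# Boolean mask: True means "this position may be attended to" (causal example).
keep = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()

# Additive mask: 0 where attention is allowed, -inf where it should be nuked.
additive_mask = torch.zeros(seq_len, seq_len)
additive_mask.masked_fill_(~keep, float("-inf"))

# The mask is added to the raw attention scores before the softmax,
# so masked positions end up with zero attention weight.
scores = torch.randn(batch, heads, seq_len, seq_len)
weights = torch.softmax(scores + additive_mask, dim=-1)
```

Because the mask is additive, any other bias (e.g. an ALiBi-style distance penalty) can be folded into the same tensor.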

Feel free to reach out if something looks suspicious; xformers probably lacks some explanations here and there, and any feedback is welcome. If mail is easier, benjamin lefaudeux $ pm me would work. Cheers

kashif commented 2 years ago

oh thank you so much! let me have a look and get back to you!! I truly appreciate your help!

kashif commented 2 years ago

@blefaudeux one issue I have: if I set the encoder of my transformer to be an xformers encoder block and then use the standard decoder from nn.Transformer, the model doesn't seem to train or learn anything... see the notebook here: https://github.com/kashif/pytorch-transformer-ts/blob/main/xformers/xformers.ipynb for how I am doing that. Would you have any intuition on why this is the case?

Thanks!!

blefaudeux commented 2 years ago

hmm, checking the code right now, but that's often a sign that the graph is broken, i.e. autograd cannot walk back the chain from the final loss to the inputs. It can happen when you swap variables in the middle, do some in-place operations (there's a guard for that and it normally asserts), or mess up a transform in the middle which makes things uniformly random.

When you say that it does not train, the loss changes but does not improve, right? If that's the case I would go for the third point: the graph is not broken, but some operation in the middle randomizes the data.
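As a side note, a tiny illustration of the in-place guard mentioned in the second point (made-up tensors, unrelated to the notebook):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)   # sigmoid saves its output for the backward pass
y.add_(1.0)            # in-place edit of a tensor autograd still needs
y.sum().backward()     # RuntimeError: a variable needed for gradient
                       # computation has been modified by an inplace operation
```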

blefaudeux commented 2 years ago

to give you an example with pictures: if you reshape a batched tensor in the middle of a model and by mistake mix the contents from all the pictures in doing so (which can happen relatively easily given some reshape assumptions), then there is nothing to learn from the end of the pipe; the data is effectively randomized.
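A concrete (made-up) example of that kind of accidental shuffling: reinterpreting a sequence-first tensor as batch-first with reshape instead of transpose mixes tokens from different batch elements.

```python
import torch

# Hypothetical sizes: 3 "pictures" in a batch, 4 tokens each, embedding dim 2.
seq_len, batch, emb = 4, 3, 2
x = torch.arange(seq_len * batch * emb, dtype=torch.float32).reshape(seq_len, batch, emb)

# Correct: swap the axes, each batch element keeps its own tokens.
ok = x.transpose(0, 1)            # shape [batch, seq_len, emb]

# Wrong: reinterpret the memory; tokens from different batch elements get mixed.
broken = x.reshape(batch, seq_len, emb)

print(torch.equal(ok, broken))    # False: the "pictures" are shuffled together
```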

kashif commented 2 years ago

right, it could be some shuffling going on... I will check what the format of the memory output from the xformers encoder is vs. the vanilla PyTorch encoder... thanks!

blefaudeux commented 2 years ago

> right, it could be some shuffling going on... I will check what the format of the memory output from the xformers encoder is vs. the vanilla PyTorch encoder... thanks!

ahh, that makes me think: there's an option when constructing the PyTorch transformers to say that you are "batch first", maybe that's the cause. xFormers follows [Batch x Context x Embedding] everywhere.

kashif commented 2 years ago

right, that is taken care of: I make nn.Transformer batch_first, so internally it transposes the output from the xformers encoder via transpose(0, 1) and then passes that to the MHA function etc.
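A minimal sketch of that shape handling, with made-up sizes and a plain nn.TransformerDecoder standing in for the decoder side (the xFormers encoder itself is left out and only represented by its batch-first output):

```python
import torch
import torch.nn as nn

# Hypothetical shapes; "memory" stands in for the xFormers encoder output,
# which is laid out [Batch, Context, Embedding].
batch, src_len, tgt_len, emb = 8, 24, 12, 32
memory = torch.randn(batch, src_len, emb)   # would come from the xFormers encoder
tgt = torch.randn(batch, tgt_len, emb)

# Option A: a batch_first PyTorch decoder consumes [B, S, E] directly.
decoder_bf = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=emb, nhead=4, batch_first=True),
    num_layers=2,
)
out = decoder_bf(tgt, memory)               # [batch, tgt_len, emb]

# Option B: the default (sequence-first) decoder expects [S, B, E],
# so the batch-first memory has to be transposed on the way in.
decoder_sf = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=emb, nhead=4),
    num_layers=2,
)
out_sf = decoder_sf(tgt.transpose(0, 1), memory.transpose(0, 1))  # [tgt_len, batch, emb]
```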

If I use the xformers decoder the model trains etc., but if I use the MHA-based PyTorch decoder it doesn't seem to train...

I believe the MHA is implemented in C++ on the PyTorch side, so I might try a Python version of it to see...

kashif commented 2 years ago

@blefaudeux I believe I got it to train, and with Nyström attention etc. it works... however, with some other attention variants, e.g. random and others, I do get errors like:

RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.

should I open an issue for these on the xformers side?

kashif commented 2 years ago

@blefaudeux ok, so I figured out that the issue with the inference tensors (above) occurs because in my validation step I am using torch.inference_mode(); if I change it to torch.no_grad() I no longer get the RuntimeError...
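A tiny reproduction of the difference between the two context managers (made-up model, only to illustrate the mechanism behind the RuntimeError above):

```python
import torch

model = torch.nn.Linear(4, 1)
x = torch.randn(2, 4)

# Tensors produced under inference_mode() are special "inference tensors":
# re-using them later in a computation that autograd records triggers the
# "Inference tensors cannot be saved for backward" error (unless you clone them).
with torch.inference_mode():
    cached = model(x)

# no_grad() only disables gradient tracking; the result is a normal tensor
# that can safely be reused alongside autograd later.
with torch.no_grad():
    cached_ok = model(x)

loss = (model(x) * cached_ok).sum()
loss.backward()                      # fine

# loss = (model(x) * cached).sum()   # would raise the RuntimeError above
```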

blefaudeux commented 2 years ago

oh great! sorry for the delay, I saw your message but was not sure about that one; I'm glad that you found out what was happening. I didn't know about torch.inference_mode(), I'll look it up.