lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in PyTorch
MIT License

Redrawing normalized samples using QR slows down training #6

Closed: Parskatt closed this issue 3 years ago

Parskatt commented 3 years ago

Doing the QR decomposition here:

https://github.com/lucidrains/performer-pytorch/blob/f9765c4ec1006f073c42d745588e78e3fb134537/performer_pytorch/performer_pytorch.py#L67-L70

slows down training substantially (at least for batch sizes of ~4). For example, in my own experiments I get ~2.5 batches/s per GPU without redrawing, but only ~1.4 batches/s with redrawing.

I found one solution in GPyTorch, which dispatches small QR factorizations to the CPU:

https://github.com/cornellius-gp/gpytorch/pull/1224

Perhaps a similar strategy could be used here? Since num_cols should rarely exceed ~100, it may make sense to always run the QR on the CPU.
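For concreteness, the CPU dispatch could look roughly like this (a minimal sketch, assuming the current `torch.linalg.qr` API; the function name is illustrative, not the repo's actual code):

```python
import torch

def gaussian_orthogonal_block(num_rows, num_cols, device=None):
    # Draw a random Gaussian block and orthogonalize it via QR.
    # Doing the factorization on the CPU avoids the slow GPU QR kernels
    # for small matrices (num_cols on the order of ~100).
    unstructured_block = torch.randn(num_rows, num_cols)
    q, _ = torch.linalg.qr(unstructured_block)
    return q.t().to(device)
```

Dispatching to the CPU only below some size threshold, as the GPyTorch PR does, would keep the GPU path for larger factorizations.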

Parskatt commented 3 years ago

Using the CPU instead of the GPU gives me ~2 batches/s. It's not perfect, but it's better.

lucidrains commented 3 years ago

@Parskatt thank you for looking into this! I noticed this as well, but didn't know the CPU would be faster.

https://github.com/lucidrains/performer-pytorch/releases/tag/0.0.9

Parskatt commented 3 years ago

Thanks for being super fast as usual :)

I think I will personally use trainable projection matrices, initialized as N(0, I). I'll let you know if it works out ;)
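Roughly what I have in mind (a minimal sketch; the module and parameter names are illustrative, not from this repo):

```python
import torch
import torch.nn as nn

class TrainableProjection(nn.Module):
    def __init__(self, num_features, dim_heads):
        super().__init__()
        # Learnable projection matrix, initialized from N(0, I),
        # in place of periodically redrawing an orthogonal one via QR.
        self.projection = nn.Parameter(torch.randn(num_features, dim_heads))

    def forward(self, x):
        # x: (..., dim_heads) -> (..., num_features)
        return x @ self.projection.t()
```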

I'll close this issue.

lucidrains commented 3 years ago

@Parskatt please do!