Closed Parskatt closed 3 years ago
Using CPU instead of GPU gives me ~2 batches/s. It's not perfect, but it's better.
@Parskatt thank you for looking into this! I noticed this as well, but didn't know the CPU would be faster
https://github.com/lucidrains/performer-pytorch/releases/tag/0.0.9
Thanks for being super fast as usual :)
I think I will personally use trainable projection matrices, initialized as N(0,I). I'll let you know if it works out ;)
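For what it's worth, the trainable-projection idea could look something like this: a minimal sketch (module and parameter names are mine, not from performer-pytorch) of a projection matrix initialized from N(0, I) and then left to gradient descent instead of being redrawn:

```python
import torch
import torch.nn as nn

class TrainableProjection(nn.Module):
    """Hypothetical sketch: a learnable random-feature projection,
    initialized with i.i.d. standard-normal entries (N(0, I)) and
    trained instead of periodically redrawn via QR."""

    def __init__(self, num_rows: int, num_cols: int):
        super().__init__()
        # Same distribution an unstructured Gaussian block starts from,
        # but registered as a parameter so the optimizer updates it.
        self.projection = nn.Parameter(torch.randn(num_rows, num_cols))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., num_cols) -> (..., num_rows)
        return x @ self.projection.t()

proj = TrainableProjection(num_rows=128, num_cols=64)
out = proj(torch.randn(4, 10, 64))
```

This avoids the per-redraw QR cost entirely, at the price of losing the orthogonality guarantee the redrawn features have.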
I'll close this issue
@Parskatt please do!
Doing the QR-decomposition:
https://github.com/lucidrains/performer-pytorch/blob/f9765c4ec1006f073c42d745588e78e3fb134537/performer_pytorch/performer_pytorch.py#L67-L70
Slows down training substantially (at least for batch sizes of ~4). For example, in my own experiments I get ~2.5 batches/s per GPU without redrawing, and ~1.4 batches/s with redrawing.
I found one solution from pytorch GP, which dispatches to CPU for small QR factorizations:
https://github.com/cornellius-gp/gpytorch/pull/1224
Perhaps a similar strategy could be used? I think num_cols should never really be more than ~100, so perhaps you should always use the CPU here?
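The gpytorch-style workaround would be roughly this: a sketch (function name is illustrative, not the library's API) that moves the small block to CPU, factorizes there, and moves the orthogonal factor back to the original device:

```python
import torch

def qr_on_cpu(unstructured_block: torch.Tensor) -> torch.Tensor:
    """Sketch of dispatching a small QR factorization to CPU, where it is
    often faster than on GPU for matrices of this size, then returning
    the orthogonal factor on the tensor's original device."""
    device = unstructured_block.device
    # QR of a small (~100 x ~100) block: cheap on CPU, avoids the
    # GPU kernel-launch and sync overhead seen during redraws.
    q, _ = torch.linalg.qr(unstructured_block.cpu())
    return q.to(device)

block = torch.randn(100, 100)
q = qr_on_cpu(block)
```

Whether this wins in practice depends on the transfer cost, but for blocks this small the copy is negligible next to the GPU factorization overhead.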