salesforce / pytorch-qrnn

PyTorch implementation of the Quasi-Recurrent Neural Network - up to 16 times faster than NVIDIA's cuDNN LSTM
BSD 3-Clause "New" or "Revised" License

[WIP] Does this Cython translation look correct? #1

Open honnibal opened 6 years ago

honnibal commented 6 years ago

Hey,

I'd like to try this out in spaCy, so I'm porting it to Thinc. The first step is to understand it by implementing it ^_^.

I translated the naming into my own conventions, which is admittedly obnoxious. I can translate it back once it's correct. For now, does this look like it's doing the right thing?

salesforce-cla[bot] commented 6 years ago

Thanks for the contribution! Before we can merge this, we need @honnibal to sign the Salesforce Contributor License Agreement.

Smerity commented 6 years ago

Excited to see what this is like Cythonized :)

I'm going to focus on the forward pass (recurrent_forget_mult), as it's saner for working out the indexing, which is sadly most of the hair-pulling -_- lol
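As a sketch of that indexing (an illustration, not the repository's kernel), here is the forward recurrence over flat, row-major buffers. It assumes a contiguous (seq_len, batch, hidden) layout and an output buffer with one extra leading timestep that holds the initial hidden state; the outer loops stand in for one GPU thread per (batch, hidden) cell, and the inner loop walks time sequentially. `forget_mult_flat` is a hypothetical name.

```python
def forget_mult_flat(f, x, h0, seq_len, batch, hidden):
    """Hypothetical forward forget-mult over flat (seq_len, batch, hidden) buffers.

    f, x: flat lists of length seq_len * batch * hidden.
    h0:   flat list of length batch * hidden (initial hidden state).
    Returns a flat list of length (seq_len + 1) * batch * hidden where
    timestep 0 holds h0 and timestep t + 1 holds h_t.
    """
    step = batch * hidden                 # stride between consecutive timesteps
    dst = [0.0] * ((seq_len + 1) * step)
    dst[:step] = list(h0)                 # timestep 0 of dst is the initial state
    for b in range(batch):                # one "thread" per (batch, hidden) cell
        for h in range(hidden):
            for t in range(seq_len):      # time must be walked sequentially
                # Same offset addresses f_t, x_t and h_{t-1}, because dst is
                # shifted one timestep ahead of f and x.
                i = t * step + b * hidden + h
                dst[i + step] = f[i] * x[i] + (1.0 - f[i]) * dst[i]
    return dst


print(forget_mult_flat([0.5, 0.5], [2.0, 4.0], [0.0], 2, 1, 1))  # [0.0, 1.0, 2.5]
```

The single offset `i` works only because of the one-step shift in `dst`; forgetting that shift is the classic off-by-one when porting this loop.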

The easiest way to test this at scale is probably to drop it in place of CPUForgetMult and run the code to check that both produce the same results.
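A minimal sketch of that kind of check, using NumPy stand-ins rather than the repository's PyTorch `CPUForgetMult`: a reference that loops over time and is vectorised over batch and hidden units, compared on random inputs with an explicit scalar-loop candidate (roughly the shape a Cython port would take). The function names are hypothetical, and the sketch assumes a (seq_len, batch, hidden) layout and the recurrence h_t = f_t * x_t + (1 - f_t) * h_{t-1}.

```python
import numpy as np


def forget_mult_reference(f, x, h0):
    """Hypothetical reference: loop over time, vectorised over (batch, hidden)."""
    out = np.empty_like(x)
    prev = h0
    for t in range(x.shape[0]):
        prev = f[t] * x[t] + (1.0 - f[t]) * prev
        out[t] = prev
    return out


def forget_mult_candidate(f, x, h0):
    """Hypothetical stand-in for the port under test: explicit scalar loops."""
    seq_len, batch, hidden = x.shape
    out = np.empty_like(x)
    for b in range(batch):
        for h in range(hidden):
            prev = h0[b, h]
            for t in range(seq_len):
                prev = f[t, b, h] * x[t, b, h] + (1.0 - f[t, b, h]) * prev
                out[t, b, h] = prev
    return out


def check_equivalent(seq_len=7, batch=3, hidden=5, seed=0):
    """Compare both implementations on random inputs of the given shape."""
    rng = np.random.default_rng(seed)
    f = rng.uniform(size=(seq_len, batch, hidden))  # forget gates in [0, 1)
    x = rng.standard_normal((seq_len, batch, hidden))
    h0 = rng.standard_normal((batch, hidden))
    return np.allclose(forget_mult_reference(f, x, h0), forget_mult_candidate(f, x, h0))


print(check_equivalent())  # True
```

Using shapes where seq_len, batch and hidden all differ matters: with equal sizes, a transposed index can still pass.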

honnibal commented 6 years ago

Hmm, well, I think the indexing is correct. It'll die rudely if it's not, so that part is easy to catch. I mostly wanted to make sure I had the loop structure and the inner equations right, because I don't speak CUDA very well.
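For reference, the inner equation being checked is, as we understand the forward recurrence in the repository's CPU reference, the following, with $f_t$ the forget gate, $x_t$ the candidate input and $\odot$ elementwise multiplication. The loop over $t$ is inherently sequential, while every (batch, hidden) cell is independent, which is what the CUDA version parallelises over.

$$
h_t = f_t \odot x_t + (1 - f_t) \odot h_{t-1}, \qquad t = 1, \dots, T
$$

As a one-cell worked example: with $h_0 = 0$, $f_1 = f_2 = 0.5$, $x_1 = 2$ and $x_2 = 4$, we get $h_1 = 0.5 \cdot 2 + 0.5 \cdot 0 = 1$ and $h_2 = 0.5 \cdot 4 + 0.5 \cdot 1 = 2.5$.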

Is the CNN fast in PyTorch? I've been working on faster code for the CNN in spaCy. The code is here: https://github.com/explosion/thinc/blob/feature/fast-mwe/thinc/neural/_fast_maxout_cnn.pyx

I'm getting about a 2x tagger speed-up with this over my numpy-based implementation, so I'm not sure it will be worth the effort. The problem is that the speed strategy relies on turning off BLAS threading and instead parallelising the loop over minibatch examples. If PyTorch makes it easy to call sgemm at the C level from Cython, it might be worth trying this out. Otherwise, I think the standard CNN implementation is sound.
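A rough NumPy analogue of that strategy (not the Cython code linked above): pin the BLAS library to a single thread through environment variables, which must be set before NumPy loads BLAS, then use a thread pool over the examples in a minibatch, relying on NumPy releasing the GIL inside BLAS-backed matrix multiplication. The dense layer and the names `layer_per_example` and `forward_batch` are hypothetical stand-ins for the CNN's per-example sgemm calls, and any real speed-up depends on matrix sizes and the BLAS build.

```python
import os

# Assumption: turn off BLAS-internal threading *before* NumPy is imported.
# These variables are read by OpenMP-, OpenBLAS- and MKL-based builds respectively.
for var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ[var] = "1"

from concurrent.futures import ThreadPoolExecutor

import numpy as np


def layer_per_example(X, W, b):
    """Hypothetical per-example layer: X is (n_words, n_in), W is (n_out, n_in)."""
    return X @ W.T + b  # NumPy typically releases the GIL inside this BLAS call


def forward_batch(docs, W, b, n_workers=4):
    """Parallelise over minibatch examples instead of inside each matmul."""
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        return list(pool.map(lambda X: layer_per_example(X, W, b), docs))


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16))
b = np.zeros(8)
docs = [rng.standard_normal((n, 16)) for n in (3, 5, 7)]  # ragged document lengths
outs = forward_batch(docs, W, b)
print([o.shape for o in outs])  # [(3, 8), (5, 8), (7, 8)]
```

In Cython the same idea would more likely be a `prange` loop with the GIL released, calling `sgemm` directly (for example through SciPy's `scipy.linalg.cython_blas`); this sketch only illustrates the loop structure, not the performance.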