asappresearch / sru

Training RNNs as Fast as CNNs (https://arxiv.org/abs/1709.02755)
MIT License

Question about the SRU training speed in tensorflow #21

Closed by johnnykthink 7 years ago

johnnykthink commented 7 years ago

Hi Tao,

Thanks for the great work. I have implemented your paper in TensorFlow. On a seq2seq task, SRU trains 1.6x faster than TensorFlow's BasicLSTMCell, and its accuracy is slightly better than the LSTM's.

But how can I get the 5-10x speed-up reported in your paper? For now I feed data to the model through feed_dict; I will switch to TFRecords later for comparison.

Thanks.

taolei87 commented 7 years ago

Hi, our speed-optimized SRU has its own forward and backward implementations.

I haven't used TensorFlow, but after googling a bit it seems possible to write a custom op in TensorFlow as well: https://www.tensorflow.org/extend/adding_an_op#implement_the_kernel_for_the_op It may need a bit of coding, though.

Thank you for trying SRU in TF. I would be happy to share your repo on our page if you would like :)

johnnykthink commented 7 years ago

@taolei87 Yes, I created a repo: https://github.com/johnnykthink/SRU-Tensorflow. I will try adding batch normalization to it and writing a custom op.

Thanks for your work; your model is simple but effective ;-)

taolei87 commented 7 years ago

Hi @johnnykthink, I read your code. The matrix multiplication is performed in every recurrence step. You could potentially optimize it by using a single matrix multiplication for all tokens in the sequence [x_1, ..., x_n]. I've done a similar optimization in Theano before, and it gave a significant speed-up.
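
To make the suggestion concrete, here is a minimal NumPy sketch (not the code from either repository). It assumes the SRU formulation from the paper with a highway connection, so the input and hidden sizes are equal. The function names, shapes, and random weights are made up for illustration. The first function multiplies inside every step; the second does one matrix product for all tokens and all three projections, leaving only elementwise work in the loop.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru_per_step(x, W, Wf, bf, Wr, br):
    """Reference version: three matrix products inside every time step."""
    n, d = x.shape
    c = np.zeros(d)
    hs = []
    for t in range(n):
        xt = x[t] @ W                      # candidate input
        f = sigmoid(x[t] @ Wf + bf)        # forget gate
        r = sigmoid(x[t] @ Wr + br)        # reset (highway) gate
        c = f * c + (1.0 - f) * xt
        hs.append(r * np.tanh(c) + (1.0 - r) * x[t])
    return np.stack(hs)

def sru_batched(x, W, Wf, bf, Wr, br):
    """Same recurrence, with one matrix product hoisted out of the loop."""
    n, d = x.shape
    # (n, d) @ (d, 3d): every token and every projection in a single call.
    U = x @ np.concatenate([W, Wf, Wr], axis=1)
    xt = U[:, :d]
    f = sigmoid(U[:, d:2 * d] + bf)
    r = sigmoid(U[:, 2 * d:] + br)
    c = np.zeros(d)
    hs = np.empty_like(x)
    for t in range(n):                     # only elementwise operations remain
        c = f[t] * c + (1.0 - f[t]) * xt[t]
        hs[t] = r[t] * np.tanh(c) + (1.0 - r[t]) * x[t]
    return hs

# Illustrative sizes and random weights (assumptions, not values from the thread).
rng = np.random.default_rng(0)
n, d = 6, 4
x = rng.standard_normal((n, d))
W, Wf, Wr = (rng.standard_normal((d, d)) for _ in range(3))
bf, br = np.zeros(d), np.zeros(d)

h_ref = sru_per_step(x, W, Wf, bf, Wr, br)
h_fast = sru_batched(x, W, Wf, bf, Wr, br)
assert np.allclose(h_ref, h_fast)
```

Both versions should agree up to floating-point rounding. On a GPU the benefit comes from replacing many small matrix products with one large one, which this CPU sketch does not try to measure.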

johnnykthink commented 7 years ago

@taolei87 Hi, that's a good idea. But if you use TensorFlow's built-in RNNCell, dynamic_rnn, or seq2seq model, then for the sake of module reusability you can only rewrite the cell implementation for a single time step.

Apart from rewriting the entire RNN in TensorFlow, I can't find any example showing how to do this.
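
One possible way to keep a per-step cell interface is to project the whole input tensor before the loop and hand the cell pre-projected inputs. The NumPy sketch below shows only that structure. `LightSRUCell` and `run_rnn` are hypothetical names standing in for a custom cell and a generic unrolling helper; they are not TensorFlow APIs, and whether this fits a given seq2seq module is an assumption to check.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

class LightSRUCell:
    """Hypothetical cell: inputs arrive already projected, so a step is elementwise."""

    def __init__(self, d, bf, br):
        self.d, self.bf, self.br = d, bf, br

    def __call__(self, inputs, c_prev):
        d = self.d
        # inputs holds [W x_t, W_f x_t, W_r x_t, x_t] side by side: (batch, 4d)
        xt, uf, ur, x = (inputs[:, i * d:(i + 1) * d] for i in range(4))
        f = sigmoid(uf + self.bf)
        r = sigmoid(ur + self.br)
        c = f * c_prev + (1.0 - f) * xt
        h = r * np.tanh(c) + (1.0 - r) * x
        return h, c

def run_rnn(cell, inputs, state):
    """Stand-in for a generic unrolling helper: one cell call per time step."""
    outputs = []
    for t in range(inputs.shape[1]):
        out, state = cell(inputs[:, t], state)
        outputs.append(out)
    return np.stack(outputs, axis=1), state

# Illustrative sizes and random weights (assumptions).
rng = np.random.default_rng(0)
batch, n, d = 2, 5, 3
x = rng.standard_normal((batch, n, d))
W_all = rng.standard_normal((d, 3 * d))    # [W, W_f, W_r] side by side

# One matrix product over every token of every sequence, outside the loop.
projected = np.concatenate([x @ W_all, x], axis=-1)   # (batch, n, 4d)

cell = LightSRUCell(d, bf=np.zeros(d), br=np.zeros(d))
outputs, final_c = run_rnn(cell, projected, np.zeros((batch, d)))
```

The cell still processes a single time step, so the unrolling code does not need to change; the cost is that the cell's input is four times wider than the raw token vector.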

byzhang commented 6 years ago

@taolei87 Do you have any performance optimization suggestions for https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/rnn_cell.py#L2718 ?