stickeritis / sticker

Succeeded by SyntaxDot: https://github.com/tensordot/syntaxdot

Switch to the Keras LSTM/GRU implementation #166

Closed: danieldk closed this 4 years ago

danieldk commented 4 years ago

Recent versions of Tensorflow Keras will automatically switch between cuDNN and Tensorflow implementations. The trained parameters work regardless of the selected implementation.

The conditions for using the cuDNN implementation are documented at:

https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM

They boil down to: (1) an NVIDIA GPU is available, and (2) certain hyperparameters are set to specific values (e.g. the activation must be tanh and the recurrent activation sigmoid). If the cuDNN implementation is selected, this results in a nice speedup.
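A minimal sketch of the selection rule as a plain predicate, assuming the condition list from the Tensorflow LSTM documentation linked above (tanh activation, sigmoid recurrent activation, recurrent_dropout == 0, unroll=False, use_bias=True, and masked inputs being strictly right-padded). RNNConfig and cudnn_eligible are hypothetical names for illustration only; Keras performs the equivalent check internally, and the exact conditions may differ between Tensorflow versions.

```python
from dataclasses import dataclass


@dataclass
class RNNConfig:
    """Hypothetical subset of the Keras LSTM arguments that affect kernel selection."""
    activation: str = "tanh"
    recurrent_activation: str = "sigmoid"
    recurrent_dropout: float = 0.0
    unroll: bool = False
    use_bias: bool = True


def cudnn_eligible(cfg, gpu_available, inputs_masked=False, strictly_right_padded=False):
    """Sketch of the documented conditions for selecting the cuDNN kernel."""
    hyperparameters_ok = (
        cfg.activation == "tanh"
        and cfg.recurrent_activation == "sigmoid"
        and cfg.recurrent_dropout == 0
        and not cfg.unroll
        and cfg.use_bias
    )
    # Masked inputs are only acceptable when they are strictly right-padded.
    masking_ok = (not inputs_masked) or strictly_right_padded
    return gpu_available and hyperparameters_ok and masking_ok


print(cudnn_eligible(RNNConfig(), gpu_available=True))                   # True
print(cudnn_eligible(RNNConfig(activation="relu"), gpu_available=True))  # False
```

Any configuration that fails the predicate still works; it just runs on the generic Tensorflow implementation instead of cuDNN.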

The Tensorflow requirement is bumped to 1.15.0. With 1.14.0, this setup fails with a constant folding error in Grappler:

https://github.com/tensorflow/tensorflow/issues/29525
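To fail early instead of running into the Grappler error, a runtime guard could look like the sketch below. This is a hypothetical helper, not part of the PR; in practice the version string would come from tf.__version__, and pre-release suffixes such as "-rc1" are simply ignored.

```python
import re

MINIMUM_TF = (1, 15, 0)  # 1.14.0 hits the Grappler constant folding error


def parse_version(version):
    """Turn '1.15.0' or '2.0.0-rc1' into a comparable tuple such as (1, 15, 0)."""
    parts = []
    for component in version.split(".")[:3]:
        match = re.match(r"\d+", component)
        parts.append(int(match.group()) if match else 0)
    # Pad short versions such as "1.15" to three components.
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def check_tensorflow_version(version):
    """Raise if the given Tensorflow version is older than the required minimum."""
    if parse_version(version) < MINIMUM_TF:
        raise RuntimeError(
            f"Tensorflow {version} is too old; >= 1.15.0 is required "
            "(1.14.0 fails with a constant folding error in Grappler)"
        )
```

Usage would be check_tensorflow_version(tf.__version__) right after importing Tensorflow.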

danieldk commented 4 years ago

I investigated this a bit after the discussion during the A3 meeting last Thursday. It turns out that we do not need any of the 3 or 4 articles (I forget which number it was) about sharing parameters between the cuDNN and Tensorflow implementations for GPU vs. CPU. The Keras LSTM/GRU layers wrap both implementations and select one on the fly. They harmonize the default hyperparameters and do the extra bias juggling imposed by the cuDNN implementation for you.
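To illustrate the bias juggling for the LSTM: the cuDNN kernel keeps two bias vectors per gate (one added to the input projection, one to the recurrent projection), whereas the canonical Keras LSTM keeps a single bias. As far as we can tell, the Keras wrapper bridges this by passing zeros as the input bias and the Keras bias as the recurrent bias, so the sum is unchanged. The sketch below checks that equivalence for one gate pre-activation in pure Python; keras_bias_to_cudnn and the other helpers are hypothetical names, not Keras functions.

```python
def matvec(matrix, vector):
    """Plain matrix-vector product on nested lists."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def keras_preactivation(W, R, b, x, h):
    """Canonical gate pre-activation with a single bias: W x + R h + b."""
    return [wx + rh + bias for wx, rh, bias in zip(matvec(W, x), matvec(R, h), b)]


def cudnn_preactivation(W, R, b_input, b_recurrent, x, h):
    """cuDNN-style pre-activation with two biases: W x + b_input + R h + b_recurrent."""
    return [wx + bi + rh + br
            for wx, bi, rh, br in zip(matvec(W, x), b_input, matvec(R, h), b_recurrent)]


def keras_bias_to_cudnn(b):
    """Hypothetical conversion: zero input bias, Keras bias as recurrent bias."""
    return [0] * len(b), list(b)


W = [[1, 2], [3, 4]]  # input weights of one gate (2 units, 2 features)
R = [[1, 0], [0, 1]]  # recurrent weights
b = [1, -1]           # single Keras bias
x, h = [1, 1], [2, 3]

b_input, b_recurrent = keras_bias_to_cudnn(b)
print(keras_preactivation(W, R, b, x, h))                     # [6, 9]
print(cudnn_preactivation(W, R, b_input, b_recurrent, x, h))  # [6, 9]
```

Because only the sum of the two cuDNN biases enters the LSTM pre-activation, the same trained Keras parameters give identical results on either kernel.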

I made this a draft PR because I still need to test it more. A simple RNN parser seemed to train well, but I need to train some other models, verify that CPU prediction gives the same results, and check that there are no (large) regressions in CPU performance.

twuebi commented 4 years ago

How does the Keras implementation deal with sequence lengths? At which index does the backward pass start for each row?

danieldk commented 4 years ago

How does the Keras implementation deal with sequence lengths? At which index does the backward pass start for each row?

The length is computed from the mask (if provided). However, reading the documentation, I concluded that masking is not supported, since one of the stated requirements for selecting the cuDNN layer is:

Inputs are not masked or strictly right padded.

However, looking at the implementation, masking should be supported by the cuDNN layer:

https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/keras/layers/recurrent_v2.py#L951

So the documentation is ambiguous. The intended reading is:

Inputs are (not masked) or strictly right padded.

and not my initial read:

Inputs are not (masked or strictly right padded).

I'll update the PR accordingly.
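A pure-Python sketch of the points above, assuming boolean masks of shape [batch, time]: the length of each row is the number of unmasked steps, so the backward pass of a row starts at index length - 1, and reversing only the first length steps of each row mirrors the semantics of tf.reverse_sequence. All helper names are hypothetical. The last two functions spell out the two readings of the documentation sentence.

```python
def lengths_from_mask(mask):
    """Length of each row = number of unmasked (True) time steps."""
    return [sum(1 for keep in row if keep) for row in mask]


def is_strictly_right_padded(mask):
    """True if each row is a run of True values followed only by False values."""
    for row, n in zip(mask, lengths_from_mask(mask)):
        if list(row) != [True] * n + [False] * (len(row) - n):
            return False
    return True


def reverse_sequence(batch, lengths):
    """Reverse the first `length` steps of each row, leaving padding in place.

    Like tf.reverse_sequence: the backward pass for a row starts at
    index length - 1, not at the last (padded) position.
    """
    return [list(row[:n])[::-1] + list(row[n:]) for row, n in zip(batch, lengths)]


# The two readings of "Inputs are not masked or strictly right padded."
def intended_reading(masked, right_padded):
    return (not masked) or right_padded


def initial_reading(masked, right_padded):
    return not (masked or right_padded)


mask = [[True, True, True, False],
        [True, False, False, False]]
batch = [[1, 2, 3, 0],
         [4, 0, 0, 0]]
lengths = lengths_from_mask(mask)
print(lengths)                           # [3, 1]
print(is_strictly_right_padded(mask))    # True
print(reverse_sequence(batch, lengths))  # [[3, 2, 1, 0], [4, 0, 0, 0]]
print(intended_reading(True, True), initial_reading(True, True))  # True False
```

The readings only disagree for masked, strictly right-padded input, which is exactly the case we care about.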

danieldk commented 4 years ago

Adding my e-mail here, so that this is documented:

Bidirectional RNNs do not work once masking is enabled.

My current theory is that there is a mismatch between the cuDNN layer and the rest of the Keras API. When go_backwards is used, the states are supposed to be returned in reverse order, so the Bidirectional layer reverses them before concatenation. However, if I read the code correctly, the cuDNN layer reverses the sequence again when masking is applied:

https://github.com/tensorflow/tensorflow/blob/6d7211299d878c4cdb499479eb6dd460b22e994f/tensorflow/python/keras/layers/recurrent_v2.py#L1363

So the sequence appears to be reversed twice. When no masking is used, by contrast, the sequence is only reversed before prediction:

https://github.com/tensorflow/tensorflow/blob/6d7211299d878c4cdb499479eb6dd460b22e994f/tensorflow/python/keras/layers/recurrent_v2.py#L1371

It's just a working theory, but it explains the outcome. With the cuDNN layer:

  1. Without masking, Bidirectional works. The sequence is not reversed by the LSTM, but by the Bidirectional layer. The correct states are concatenated.
  2. With masking, Bidirectional does not work. The sequence is reversed by the LSTM (to get the original order), but reversed again by Bidirectional. So the last state of the backwards direction is concatenated with the first state of the forward direction, etc.
  3. With masking, a simple concat works. The sequence is reversed by the LSTM but not reversed again.

However, (3) breaks with CPU prediction: the CPU fallback implementation returns the states in reverse order, so the wrong states get concatenated again.
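The theory can be checked with a toy model in pure Python. Each forward output is labeled f{t} and each backward output b{t}; a concatenated pair is correct when both labels refer to the same time step. The model ignores padding and real tensors, and the reverse_back flag is an assumption standing in for the extra reversal in the masked cuDNN path.

```python
def forward_outputs(n):
    """Forward RNN outputs in time order: f0, f1, ..., f{n-1}."""
    return [f"f{t}" for t in range(n)]


def backward_outputs(n, reverse_back):
    """Backward RNN outputs; b{t} is the state after reading x[n-1], ..., x[t].

    The pass visits t = n-1, ..., 0 and emits outputs in that order, which is
    the go_backwards convention Bidirectional expects. reverse_back=True models
    a layer that additionally flips them back into time order itself.
    """
    outs = [f"b{t}" for t in reversed(range(n))]
    return outs[::-1] if reverse_back else outs


def bidirectional_concat(fw, bw):
    """Like Keras Bidirectional: reverse the backward outputs, then pair them."""
    return list(zip(fw, bw[::-1]))


def simple_concat(fw, bw):
    """Pair outputs without any reversal."""
    return list(zip(fw, bw))


def aligned(pairs):
    """True if every pair combines forward and backward states of the same step."""
    return all(f[1:] == b[1:] for f, b in pairs)


n = 4
fw = forward_outputs(n)
cases = {
    "1. cuDNN, no mask, Bidirectional": bidirectional_concat(fw, backward_outputs(n, False)),
    "2. cuDNN, mask, Bidirectional": bidirectional_concat(fw, backward_outputs(n, True)),
    "3. cuDNN, mask, simple concat": simple_concat(fw, backward_outputs(n, True)),
    "3. on CPU fallback, simple concat": simple_concat(fw, backward_outputs(n, False)),
}
for name, pairs in cases.items():
    print(f"{name}: {'ok' if aligned(pairs) else 'misaligned'}")
```

Cases 1 and 3 come out aligned, while case 2 and the CPU fallback of case 3 are misaligned, matching the observations above. Dropping the extra reversal (reverse_back=False everywhere) makes Bidirectional correct on both paths.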

danieldk commented 4 years ago

I have now vendored the Keras recurrent module in this PR in the first commit. In the second commit, I removed the reverses in the cuDNN LSTM/GRU functions. I have only trained for a few epochs so far, but the scores look good, and when evaluating on CPU, the accuracy is (virtually) the same.

danieldk commented 4 years ago

@twuebi @DiveFish I have taken this out of the draft state. I have trained several models and the accuracies are on par with what they were before. The performance (in sentences per second) in CPU prediction is also the same.

danieldk commented 4 years ago

Two extra notes:

  1. We should really file a bug upstream for the reversal problem in cudnn_lstm.
  2. I have not tested GRUs yet.

twuebi commented 4 years ago

Did you test a byte-rnn?

danieldk commented 4 years ago

Did you test a byte-rnn?

Yes, I tested German NER, works as expected.

Oddly enough, Dutch UD does not work at all, though all prior experiments were on turing-sfb and this one was on hopper.

danieldk commented 4 years ago

Did you test a byte-rnn?

Yes, I tested German NER, works as expected.

Oddly enough, Dutch UD does not work at all, though all prior experiments were on turing-sfb and this one was on hopper.

Strange:

twuebi commented 4 years ago

cuda + tf versions match?

hopper: Keras RNN -> fail (stuck at accuracies in the 70s)

isn't that approximately what we had with the reversing bug?

danieldk commented 4 years ago

cuda + tf versions match?

Both are Tensorflow 1.15.0 and both use the same CUDA 10.0 from nixpkgs. The driver versions differ: 410.48 (CUDA 10.0) on turing-sfb and 418.87.00 (CUDA 10.1) on hopper. It's not cuDNN: I have already bumped cuDNN from 7.5.0 to 7.6.5, with no change. I will also try CUDA 10.1 from nixpkgs to see if that makes a difference.

danieldk commented 4 years ago

isn't that approximately what we had with the reversing bug?

Thank you thank you thank you! That saved quite a bit of time.

It turns out that those lines were not staged when I amended the commit, so they are missing from this PR (and thus on hopper, where I had just done a git fetch && git rebase keras-rnn).

I'll update the PR.

danieldk commented 4 years ago

For the base model, a training epoch now takes 38 seconds on hopper and 1:15 on turing.

danieldk commented 4 years ago

Nice speedup! What's the difference to before?

Hopper: 46 seconds per epoch. Turing: will have to do another run.
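For reference, assuming both hopper numbers refer to the same base model, the change amounts to roughly a 1.2× speedup per epoch:

```latex
\frac{46\,\mathrm{s}}{38\,\mathrm{s}} \approx 1.21,
\qquad
\frac{46\,\mathrm{s} - 38\,\mathrm{s}}{46\,\mathrm{s}} \approx 17\%
```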

Before fixing this bug, I did see very short epoch times (24 seconds) with Keras + automatic mixed precision. I don't get the same performance with the vendored Keras. Moreover, with the vendored Keras and mixed precision, I get an error during validation (I haven't investigated further):

Cannot run graph: {inner:0x5602b200fef0, InvalidArgument: 2 root error(s) found.
  (0) Invalid argument: TensorArray dtype is float but Op is trying to write dtype half.
         [[{{node model/bidirectional/while/TensorArrayWrite/TensorArrayWriteV3}}]]
         [[model/bidirectional_1/while_1/Identity/_277]]
  (1) Invalid argument: TensorArray dtype is float but Op is trying to write dtype half.
         [[{{node model/bidirectional/while/TensorArrayWrite/TensorArrayWriteV3}}]]
0 successful operations.
0 derived errors ignored.

I'll investigate this a bit more.