ryanleary / mlperf-rnnt-ref

Explore JIT for projected lstm #3

Open samgd opened 4 years ago

samgd commented 4 years ago

https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/

samgd commented 4 years ago

@mwawrzos I copied over the custom LSTM example code from PyTorch:

https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py

https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
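For reference, the LayerNormLSTMCell in custom_lstms.py is an LSTM cell with LayerNorm applied to the input-to-hidden and hidden-to-hidden gate pre-activations and to the cell state. Below is a minimal sketch of that idea extended with an output projection (hidden_size → proj_size), written as a plain nn.Module and compiled with torch.jit.script. The class names are hypothetical, this is not the code on the branch, and the weight initialisation is arbitrary.

```python
import torch
import torch.nn as nn
from typing import List, Tuple


class LayerNormProjLSTMCell(nn.Module):
    """Hypothetical LSTM cell: LayerNorm on gates and cell state, plus an output projection."""

    def __init__(self, input_size: int, hidden_size: int, proj_size: int):
        super().__init__()
        self.weight_ih = nn.Parameter(0.1 * torch.randn(4 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(0.1 * torch.randn(4 * hidden_size, proj_size))
        self.weight_hr = nn.Parameter(0.1 * torch.randn(proj_size, hidden_size))
        # As in custom_lstms.py, the LayerNorm affine parameters stand in for gate biases.
        self.ln_i = nn.LayerNorm(4 * hidden_size)
        self.ln_h = nn.LayerNorm(4 * hidden_size)
        self.ln_c = nn.LayerNorm(hidden_size)

    def forward(
        self, x: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        hx, cx = state  # hx: (batch, proj_size), cx: (batch, hidden_size)
        gates = self.ln_i(torch.mm(x, self.weight_ih.t())) + self.ln_h(
            torch.mm(hx, self.weight_hh.t())
        )
        i, f, g, o = gates.chunk(4, 1)
        cy = self.ln_c(torch.sigmoid(f) * cx + torch.sigmoid(i) * torch.tanh(g))
        # Project the hidden_size output down to proj_size before it is fed back.
        hy = torch.mm(torch.sigmoid(o) * torch.tanh(cy), self.weight_hr.t())
        return hy, cy


class LayerNormProjLSTMLayer(nn.Module):
    """Unrolls the cell over time; input is (seq_len, batch, input_size)."""

    def __init__(self, input_size: int, hidden_size: int, proj_size: int):
        super().__init__()
        self.cell = LayerNormProjLSTMCell(input_size, hidden_size, proj_size)
        self.hidden_size = hidden_size
        self.proj_size = proj_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.size(1)
        h = torch.zeros([batch, self.proj_size], dtype=x.dtype, device=x.device)
        c = torch.zeros([batch, self.hidden_size], dtype=x.dtype, device=x.device)
        outputs: List[torch.Tensor] = []
        for t in range(x.size(0)):
            h, c = self.cell(x[t], (h, c))
            outputs.append(h)
        return torch.stack(outputs)  # (seq_len, batch, proj_size)


# torch.jit.script compiles forward() to TorchScript; the JIT can then fuse the
# per-timestep elementwise ops (mainly a win on GPU).
layer = torch.jit.script(LayerNormProjLSTMLayer(input_size=16, hidden_size=32, proj_size=8))
out = layer(torch.randn(10, 4, 16))
```

Running the same module without torch.jit.script gives an eager baseline with matching outputs, which is a quick way to check that scripting has not changed the numerics.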

Pushed to this branch (it's quite messy, sorry!):

https://github.com/ryanleary/mlperf-rnnt-ref/tree/integration_myrtle_ln

(The other main fix included in this push: we had one too many outputs, i.e. len(ctc_vocab) rather than len(ctc_vocab) - 1.)

This has been running for a day but it is slow.

One issue affecting memory use and training iteration time is that NVIDIA's amp/apex library for mixed precision does not appear to play nicely with PyTorch's JIT. Can we contact anyone who could help us with this? Is there even a workaround?

WER evaluated on dev-clean every 1/3 epoch:

==========>>>>>>Evaluation WER: 0.5119297084666005
==========>>>>>>Evaluation WER: 0.3760339693393625
==========>>>>>>Evaluation WER: 0.30912466453439214

mwawrzos commented 4 years ago

Do you mean that execution is slow when JIT is in use?

mwawrzos commented 4 years ago

@samgd, what is the main purpose of this task? Can we skip it for now? I asked some people who have tried using JIT and found out that it is still not a mature enough feature and does not always bring the expected outcome.

samgd commented 4 years ago

Yeah, execution is slow due to the use of FP32. I'm investigating writing a ScriptModule variant that handles the casting etc., and checking whether it's possible to have amp ignore it. Did the amp/apex/JIT discussion bring to light any tips or tricks?

The layer-normalised LSTMs will be required for convergence, so this may not be something we can drop. Naively stacking plain LSTMs seems to yield poor results (and no useful results at all when the depth is > 6 in our initial experiments).

mwawrzos commented 4 years ago

The only tip I found was to cast to FP16 by hand, as you are doing right now. That tip was for inference, where it is as simple as calling model.half(), but for training it may not be enough.
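A minimal sketch of casting by hand around a scripted module, assuming the scripted part returns a single tensor: the wrapper keeps the inner module's parameters in FP16 and casts at the boundary, so callers see FP32 in and out. The hope is that amp then has nothing to patch inside the scripted part; whether that actually holds for apex has not been verified here. CastingWrapper is a hypothetical name, and the sketch is deliberately naive: it keeps no FP32 master weights and does no loss scaling, which is why it may not be enough for training.

```python
import torch
import torch.nn as nn


class CastingWrapper(nn.Module):
    """Hypothetical wrapper: runs `inner` in `compute_dtype`, returns float32.

    Assumes inner(x) returns a single tensor. No master weights, no loss scaling.
    """

    def __init__(self, inner: nn.Module, compute_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.inner = inner.to(compute_dtype)  # casts the inner parameters and buffers
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.inner(x.to(self.compute_dtype))  # compute in the reduced precision
        return out.float()  # hand FP32 back to the rest of the model
```

For example, CastingWrapper(torch.jit.script(layer)).cuda() would wrap a scripted layer; FP16 matmuls generally need a GPU to be fast.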

mwawrzos commented 4 years ago

If JIT doesn't work, let's try running the code without it. If the model converges, we can then decide whether to replace Layer Norm with something else or whether JIT is the best solution.

I can already run it on multiple GPUs, and I can try to run it multi-node as well. Even if it is slow as hell, it may finish in a reasonable time when scaled out enough.

samgd commented 4 years ago

JIT does work, and a run is currently underway, albeit slowly (compared to the cuDNN LSTM). If JIT and mixed precision are mutually exclusive, we should work out which is the better option to use!

mwawrzos commented 4 years ago

To clarify: a Layer Normalized LSTM with JIT in FP32 is faster than a Layer Normalized LSTM with AMP/Apex?

samgd commented 4 years ago

Yes, I see these results:

amp: 2.21 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
script: 1.32 s ± 528 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

When timing a forward pass (via %%timeit) with the following configuration:

seq_len = 1024
batch = 64
input_size = 1024
hidden_size = 1280
proj_size = 640
num_layers = 5

(warm-up done before timing starts in both cases)
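These numbers put the scripted FP32 version at roughly 1.7× the speed of amp (2.21 s / 1.32 s ≈ 1.67). For repeating the comparison outside a notebook, here is a sketch of a %%timeit-style harness: untimed warm-up calls, then 7 timed runs, reporting mean and standard deviation. time_forward is a hypothetical helper. The torch.cuda.synchronize() calls are there because CUDA kernels launch asynchronously, so without them the clock could stop before the GPU has finished.

```python
import statistics
import time
from typing import Callable, Tuple

import torch


def time_forward(fn: Callable[[], object], warmup: int = 3, runs: int = 7) -> Tuple[float, float]:
    """Return (mean, std) seconds per call of fn(), after `warmup` untimed calls."""
    if runs < 2:
        raise ValueError("need at least 2 runs for a standard deviation")
    for _ in range(warmup):  # JIT profiling/compilation and CUDA init happen here
        fn()
    times = []
    for _ in range(runs):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # don't count work queued by earlier calls
        start = time.perf_counter()
        fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)
```

As a cuDNN reference point with the configuration above, one could time lstm = torch.nn.LSTM(1024, 1280, num_layers=5, proj_size=640).cuda() on x = torch.randn(1024, 64, 1024, device="cuda") via time_forward(lambda: lstm(x)) (proj_size requires PyTorch 1.8 or later).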

mwawrzos commented 4 years ago

OK, then I will work on scaling up the training.

It would be handy to be able to quickly test whether one configuration works better than another before trying to optimize it.

That way we can check what we gain from Layer Norm and projection.
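One easily checked part of what the projection buys is model size (the WER side still needs the experiments). A sketch that counts parameters the way torch.nn.LSTM lays them out (two bias vectors per layer, plus weight_hr when proj_size > 0), applied to the benchmark configuration above (input 1024, hidden 1280, proj 640, 5 layers). lstm_param_count and count_module_params are hypothetical helpers; a layer-normalised variant would differ slightly because LayerNorm parameters replace the biases.

```python
import torch.nn as nn


def lstm_param_count(input_size: int, hidden_size: int, num_layers: int, proj_size: int = 0) -> int:
    """Parameters in a stacked LSTM laid out like torch.nn.LSTM."""
    rec_size = proj_size if proj_size > 0 else hidden_size  # size of h fed back / to next layer
    total = 0
    for layer in range(num_layers):
        in_size = input_size if layer == 0 else rec_size
        total += 4 * hidden_size * (in_size + rec_size)  # weight_ih + weight_hh
        total += 2 * 4 * hidden_size                     # bias_ih + bias_hh
        if proj_size > 0:
            total += proj_size * hidden_size             # weight_hr (projection)
    return total


def count_module_params(module: nn.Module) -> int:
    """Cross-check against an actual module."""
    return sum(p.numel() for p in module.parameters())


projected = lstm_param_count(1024, 1280, 5, proj_size=640)
unprojected = lstm_param_count(1024, 1280, 5)
```

This gives 38,881,280 parameters with the projection versus 64,276,480 without, i.e. about 40% fewer.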

samgd commented 4 years ago

@mwawrzos @ryanleary

Pushed an update that allows switching between LSTM, LSTM + BN, and LSTM + LN: https://github.com/ryanleary/mlperf-rnnt-ref/tree/integration_myrtle_ln

Untested on multiple GPUs, so it may require changes.

Can these configs be used to try and establish a baseline?

samgd commented 4 years ago

The "normal" LSTM configuration with depth 5 (slightly modified config) appears to be training on a single RTX GPU:

[rnnt]
rnn_type = "lstm"
encoder_n_hidden = 640
encoder_pre_rnn_layers = 2
encoder_stack_time_factor = 2
encoder_post_rnn_layers = 3
pred_n_hidden = 640
pred_rnn_layers = 2
forget_gate_bias = 1.0
joint_n_hidden = 512

==========>>>>>>Evaluation WER: 0.6945516708944524
==========>>>>>>Evaluation WER: 0.5440608801147017
==========>>>>>>Evaluation WER: 0.4745965221866843
==========>>>>>>Evaluation WER: 0.4367302672695857
==========>>>>>>Evaluation WER: 0.40924965993897283
==========>>>>>>Evaluation WER: 0.38812911290026103
==========>>>>>>Evaluation WER: 0.37811109885665967
==========>>>>>>Evaluation WER: 0.3633689937869931
==========>>>>>>Evaluation WER: 0.34412337781699204
==========>>>>>>Evaluation WER: 0.347009301128635
==========>>>>>>Evaluation WER: 0.34223006507113707
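The [rnnt] config above gives an encoder of depth 5 (encoder_pre_rnn_layers = 2 plus encoder_post_rnn_layers = 3), with encoder_stack_time_factor = 2 presumably applied between the pre and post layers. A minimal sketch of time stacking, assuming it concatenates each group of consecutive frames along the feature axis and drops an incomplete trailing group (the branch may pad instead); RNNT_CFG and stack_time are hypothetical names.

```python
import torch

# Values copied from the [rnnt] config above.
RNNT_CFG = {
    "encoder_n_hidden": 640,
    "encoder_pre_rnn_layers": 2,
    "encoder_stack_time_factor": 2,
    "encoder_post_rnn_layers": 3,
}


def stack_time(x: torch.Tensor, factor: int) -> torch.Tensor:
    """(T, B, F) -> (T // factor, B, F * factor) by concatenating consecutive frames.

    Assumption: trailing frames that do not fill a whole group are dropped.
    """
    T, B, F = x.shape
    t_out = T // factor
    x = x[: t_out * factor].reshape(t_out, factor, B, F)  # group consecutive frames
    return x.permute(0, 2, 1, 3).reshape(t_out, B, factor * F)  # concat along features


encoder_depth = RNNT_CFG["encoder_pre_rnn_layers"] + RNNT_CFG["encoder_post_rnn_layers"]
```

Under these assumptions the post-stacking layers see 1280-dimensional inputs (2 × 640) at half the frame rate, so they run half as many recurrent steps.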