google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

CTC Loss for Speech Recognition models #9350

Open patrickvonplaten opened 2 years ago

patrickvonplaten commented 2 years ago

Community Request for JAX CTC Loss

First of all, I'm not 100% sure whether this repo is the right place or whether we should rather open it in https://github.com/deepmind/optax or https://github.com/google/flax

Many state-of-the-art speech recognition models are fine-tuned using the CTC loss, most notably Wav2Vec2. Both PyTorch and TensorFlow have native support for the CTC loss.

People seem to be using TPUs more often, and PyTorch's CTC loss is not made for torch/XLA (see https://github.com/pytorch/xla/issues/2681#issuecomment-894468034). JAX has very good TPU support and is arguably more user-friendly for PyTorch users than TF.
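
For context, here is a minimal reference sketch of the quantity a CTC loss computes: the negative log of the total probability of all frame-level alignments that collapse to the target label sequence, summed with the standard forward recursion. It is written in plain Python rather than JAX and is not any library's implementation. The names `ctc_loss` and `logsumexp`, the list-of-lists input layout, and the choice of `blank=0` are assumptions made for illustration.

```python
import math

NEG_INF = float("-inf")


def logsumexp(*xs):
    """Stable log(sum(exp(x))) that tolerates -inf inputs (hypothetical helper)."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_loss(log_probs, labels, blank=0):
    """Negative log-likelihood of `labels` under CTC for one sequence.

    log_probs: list of T frames, each a list of per-class log-probabilities.
    labels:    list of target class ids, without blanks.
    """
    # Extended target with blanks interleaved: [blank, l1, blank, l2, ..., blank]
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    num_states = len(ext)

    # alpha[s] = log-probability of all partial alignments ending in state s
    alpha = [NEG_INF] * num_states
    alpha[0] = log_probs[0][ext[0]]
    if num_states > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, len(log_probs)):
        new_alpha = [NEG_INF] * num_states
        for s in range(num_states):
            candidates = [alpha[s]]  # stay in the same state
            if s >= 1:
                candidates.append(alpha[s - 1])  # advance by one state
            # Skipping a blank is only allowed between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(alpha[s - 2])
            new_alpha[s] = logsumexp(*candidates) + log_probs[t][ext[s]]
        alpha = new_alpha

    # A valid alignment ends on the last label or on the trailing blank.
    if num_states > 1:
        total = logsumexp(alpha[-1], alpha[-2])
    else:
        total = alpha[-1]
    return -total
```

As a sanity check, with two frames of uniform probabilities over {blank, a} and the target `[a]`, three of the four possible paths (a a, a blank, blank a) collapse to the target, so the loss is -log(0.75). A JAX version would typically replace the Python loops with `jax.lax.scan` over frames and `jnp.logaddexp` over the transition candidates, and add padding masks for batching.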

Think this could also be a cool first/second issue.

Gently pinging @marcvanzee @jheek @avital here as well

jheek commented 2 years ago

It could also live as an objective in JaxOpt

patrickvonplaten commented 2 years ago

Cool, what would you guys prefer? I think we could actually find some community members to work on adding such a loss.

patrickvonplaten commented 2 years ago

In case this has already been developed somewhere internally and will soon be open-sourced, it'd be great to hear about it :-)

hawkinsp commented 2 years ago

There's a CTC loss in Lingvo/JAX: https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py

Could you use that one?

mattjj commented 2 years ago

We're also asking the optax folks what they think. I'll assign this to myself so I remember to follow up when we hear things.

patrickvonplaten commented 2 years ago

Thanks @hawkinsp,

I'll try this one out to fine-tune a Wav2Vec2 model in Flax and report back here on whether the results are similar / equal to PyTorch.

frmccann97 commented 2 years ago

Did you get good results?

frmccann97 commented 2 years ago

Thanks @hawkinsp,

I'll try this one out to fine-tune a Wav2Vec2 model in Flax and report back here on whether the results are similar / equal to PyTorch.

How did it go?

patrickvonplaten commented 2 years ago

Did you get good results?

Yes, it works very well! See: https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_flax_speech_recognition_ctc.py

cc @sanchit-gandhi