patrickvonplaten opened this issue 2 years ago
Cool, what do you guys prefer? I think we could find some community members to actually work on adding such a loss.
In case this is already being developed somewhere internally and will soon be open-sourced, it'd be great to hear about it :-)
There's a CTC loss in Lingvo/JAX: https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/ctc_objectives.py
Could you use that one?
We're also asking the optax folks what they think. I'll assign this to myself so I remember to follow up when we hear things.
Thanks @hawkinsp,
I'll try this one out for fine-tuning a Wav2Vec2 model in Flax and report back here whether the results are similar / equal to PyTorch.
Did you get good results?
How did it go?
Yes it works very well! See: https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_flax_speech_recognition_ctc.py
cc @sanchit-gandhi
Community Request for JAX CTC Loss
First of all, I'm not 100% sure whether this repo is the right place for this request, or whether it would be better to open it in https://github.com/deepmind/optax or https://github.com/google/flax
Many state-of-the-art speech recognition models are fine-tuned using the CTC loss - most notably Wav2Vec2. Both PyTorch and Tensorflow have native support for the CTC loss - see:
People seem to be using TPUs more and more often, and PyTorch's CTC loss was not designed for torch/XLA, see: https://github.com/pytorch/xla/issues/2681#issuecomment-894468034 JAX has very good TPU support and is arguably more approachable for PyTorch users than TF.
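For context, the CTC objective itself is a dynamic program over a blank-augmented label sequence: the forward (alpha) recursion sums over all frame-level alignments that collapse to the target labels. A minimal pure-Python reference sketch of that recursion (illustrative only, not the Lingvo/optax implementation; blank id and shapes are assumptions):

```python
import math

NEG_INF = float("-inf")

def logsumexp(*xs):
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_loss(log_probs, labels, blank=0):
    """Negative log-likelihood of `labels` under CTC.

    log_probs: T x K nested list of per-frame log-probabilities.
    labels: target sequence (without blanks).
    """
    # Interleave blanks: y' = [blank, y1, blank, y2, ..., yL, blank]
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    S, T = len(ext), len(log_probs)

    # alpha[t][s] = log-prob of all alignment prefixes ending at y'[s] after frame t.
    alpha = [[NEG_INF] * S for _ in range(T)]
    alpha[0][0] = log_probs[0][ext[0]]          # start with a blank...
    if S > 1:
        alpha[0][1] = log_probs[0][ext[1]]      # ...or with the first label

    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1][s]                 # stay on the same symbol
            if s >= 1:
                a = logsumexp(a, alpha[t - 1][s - 1])   # advance by one
            # Skip the intermediate blank, unless the labels repeat.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a = logsumexp(a, alpha[t - 1][s - 2])
            alpha[t][s] = a + log_probs[t][ext[s]]

    # Valid alignments end on the final label or the final blank.
    return -logsumexp(alpha[T - 1][S - 1],
                      alpha[T - 1][S - 2] if S > 1 else NEG_INF)
```

This matches a brute-force sum over all alignments that collapse to the labels (remove repeats, then remove blanks), which is an easy way to sanity-check any CTC implementation on tiny inputs.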
I think this could also be a cool first/second issue.
Gently pinging @marcvanzee @jheek @avital here as well