baal-org / baal

Bayesian active learning library for research and industrial use cases.
https://baal.readthedocs.io
Apache License 2.0
854 stars, 84 forks

Question - Using BAAL for automatic speech recognition (speech to text) #283

Closed by ognjenkundacina 4 months ago

ognjenkundacina commented 5 months ago

Hello!

Can BAAL be used out of the box for batch active learning for automatic speech recognition models like wav2vec2? If not, can you please give me suggestions on how to implement that?

Thank you!

Dref360 commented 5 months ago

Hello! I'm not super familiar with s2t, but it should be similar to text generation which we support to some extent.

Do you use Hugging Face or another library? I can probably code something up this week to showcase this capability.

ognjenkundacina commented 5 months ago

Thanks for the help! Yes, it should be similar to text generation in terms of active learning criteria, since it is the same type of output.

I am using wav2vec2 (without the n-gram language model, which is the simpler case) from Hugging Face: https://huggingface.co/docs/transformers/model_doc/wav2vec2

Dref360 commented 5 months ago

Cool! So as expected it's pretty similar to text generation. I recorded a Loom for you and there is also a gist.

Loom Gist

Now that the uncertainty estimation seems to work, the actual active learning loop should work by following this tutorial (the part at the end where ActiveLearningLoop is used).
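Roughly, each step of such a loop scores the unlabelled pool with an uncertainty heuristic and sends the most uncertain samples to be labelled. A minimal plain-Python sketch of one acquisition step, assuming the per-sample scores are already computed; `acquire_top_k` and the toy scores are hypothetical and this is not Baal's `ActiveLearningLoop` API:

```python
from typing import List, Sequence


def acquire_top_k(pool_indices: Sequence[int], scores: Sequence[float], k: int) -> List[int]:
    """Return the k pool indices with the highest uncertainty score.

    Hypothetical helper illustrating one acquisition step; Baal's
    ActiveLearningLoop does a comparable ranking with its heuristics.
    """
    if len(pool_indices) != len(scores):
        raise ValueError("need one score per pool sample")
    # Sort pool positions by descending score and keep the first k.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [pool_indices[i] for i in ranked[:k]]


# Toy example: five unlabelled clips with made-up uncertainty scores.
pool = [10, 11, 12, 13, 14]
uncertainty = [0.2, 0.9, 0.1, 0.7, 0.4]
to_label = acquire_top_k(pool, uncertainty, k=2)
print(to_label)  # [11, 13]
```

The selected indices would then be labelled and moved to the training set before retraining.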

Let me know if this works; I'm excited to learn more about your use case. We're definitely prepared to make changes to the library to fit it.

Cheers!

ognjenkundacina commented 5 months ago

Thank you for the detailed answer! Over the next month I will be working on this in more depth and will let you know what I find.

ognjenkundacina commented 5 months ago

Thanks a lot again for the guidance and the resources you've shared; they're incredibly helpful!

I'm actually planning to research the uncertainty estimation of sequences for s2t. I'm thinking of developing a function in BAAL that can process all the sequences (transcriptions) generated through MC Dropout iterations, to compare these sequences to derive new uncertainty measures. For example, given the parameters in the code you provided, the function should have access to 20 transcriptions from MC Dropout iterations alongside a single transcription from the model without dropout for each audio sample. Could you provide any advice on how to implement this function and access the transcriptions as mentioned?

Dref360 commented 4 months ago

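It is simple to switch between MC-Dropout and deterministic inference with Baal. In theory, though, we tend to use the mean prediction (we call this the Bayesian Model Average, but in practice we just average the logits). BALD builds on this: it compares the entropy of the averaged prediction with the average entropy of the individual predictions, so it is high when the MC-Dropout samples disagree with each other.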
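A NumPy sketch of the BALD score, assuming the MC-Dropout outputs are logits stacked with the iterations on the last axis (shape `[n_samples, n_classes, iterations]`, the layout Baal uses for its predictions). Following the usual BALD formulation, it averages softmax probabilities rather than raw logits. The function names are hypothetical; for a speech model the score would be computed per frame and then aggregated, and that aggregation is left as an assumption:

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable softmax along `axis`.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def bald_scores(logits: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """BALD (mutual information) per sample.

    `logits` has shape [n_samples, n_classes, iterations]; each slice on the
    last axis is one MC-Dropout forward pass.
    """
    probs = softmax(logits, axis=1)   # class probabilities for each pass
    mean_probs = probs.mean(axis=-1)  # average over MC iterations
    # Entropy of the averaged prediction (total uncertainty).
    entropy_of_mean = -(mean_probs * np.log(mean_probs + eps)).sum(axis=1)
    # Average entropy of the individual passes.
    mean_of_entropy = -(probs * np.log(probs + eps)).sum(axis=1).mean(axis=-1)
    # The difference is large when the passes disagree with each other.
    return entropy_of_mean - mean_of_entropy


# Two samples, 3 classes, 2 MC iterations (rows = classes, columns = passes).
agree = np.array([[5.0, 5.0], [0.0, 0.0], [0.0, 0.0]])     # both passes favour class 0
disagree = np.array([[5.0, 0.0], [0.0, 5.0], [0.0, 0.0]])  # passes favour different classes
logits = np.stack([agree, disagree])                       # shape [2, 3, 2]
print(bald_scores(logits))  # first score is 0, second is clearly positive
```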

Alternating between MC-Dropout and deterministic inference

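from baal.bayesian.dropout import MCDropoutModule
from baal.transformers_trainer_wrapper import BaalTransformersTrainer

ITERATIONS = 20
your_model = ...  # Your S2T model
inputs = ...  # A batch of model inputs
wrapper = BaalTransformersTrainer(model=your_model)  # plus your usual Trainer arguments

with MCDropoutModule(your_model) as model:
    # Dropout stays active inside this block, so this is stochastic
    predictions = [model(inputs) for _ in range(ITERATIONS)]
    wrapper.predict_on_dataset(..., iterations=ITERATIONS)

# Dropout layers are restored on exit, so this is deterministic (in eval mode)
output = your_model(inputs)
wrapper.predict_on_dataset(..., iterations=1)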

You can then compute the uncertainty how you wish using both the deterministic and stochastic predictions.
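As one possible sequence-level measure (an illustrative assumption, not something Baal provides), the sketch below scores an audio sample by how much its MC-Dropout transcriptions differ from the deterministic transcription, using a character-level normalized edit distance. The function names and the choice of Levenshtein distance are hypothetical; the transcriptions would come from decoding the stochastic and deterministic predictions:

```python
from statistics import mean
from typing import Sequence


def levenshtein(a: str, b: str) -> int:
    # Dynamic-programming edit distance (insertions, deletions, substitutions).
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # deletion
                curr[j - 1] + 1,           # insertion
                prev[j - 1] + (ca != cb),  # substitution (free if chars match)
            ))
        prev = curr
    return prev[-1]


def mc_disagreement(deterministic: str, mc_samples: Sequence[str]) -> float:
    """Mean normalized edit distance between the deterministic transcription
    and each MC-Dropout transcription (0.0 means all identical)."""
    def normalized(hyp: str) -> float:
        denom = max(len(deterministic), len(hyp), 1)
        return levenshtein(deterministic, hyp) / denom
    return mean(normalized(h) for h in mc_samples)


det = "HELLO WORLD"
stable = ["HELLO WORLD"] * 3
unstable = ["HELLO WORLD", "HELO WORD", "YELLOW WORLD"]
print(mc_disagreement(det, stable))    # 0.0
print(mc_disagreement(det, unstable))  # larger than 0
```

Higher scores mark samples whose transcription is unstable under dropout, which makes them candidates for labelling.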

To compute the average prediction, note that the iteration axis is last, so you can do:

predictions = wrapper.predict_on_dataset(..., iterations=ITERATIONS)
average_pred = predictions.mean(axis=-1)
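For a CTC model such as wav2vec2, the averaged logits can then be turned into one transcription by greedy decoding. A NumPy sketch, assuming `predictions` has shape `[n_samples, n_frames, vocab_size, iterations]`, that the CTC blank is token id 0, and a toy vocabulary; `ctc_greedy_decode` and the fake data are illustrative only (in practice the Wav2Vec2 processor's `batch_decode` handles decoding):

```python
from typing import Sequence

import numpy as np


def ctc_greedy_decode(frame_logits: np.ndarray, vocab: Sequence[str], blank_id: int = 0) -> str:
    """Greedy CTC decoding of one utterance: argmax per frame,
    collapse repeated tokens, then drop blanks."""
    best = frame_logits.argmax(axis=-1).tolist()  # one token id per frame
    chars = []
    previous = None
    for token in best:
        if token != previous and token != blank_id:
            chars.append(vocab[token])
        previous = token
    return "".join(chars)


vocab = ["<blank>", "A", "B"]  # toy vocabulary, blank at id 0 (assumption)
rng = np.random.default_rng(0)
# Fake MC-Dropout output: 1 sample, 4 frames, 3 tokens, 5 iterations.
predictions = rng.normal(size=(1, 4, 3, 5))
predictions[0, :2, 1] += 10.0  # frames 0-1 strongly favour "A"
predictions[0, 2, 0] += 10.0   # frame 2 is blank
predictions[0, 3, 2] += 10.0   # frame 3 favours "B"

average_pred = predictions.mean(axis=-1)  # average over MC iterations
print(ctc_greedy_decode(average_pred[0], vocab))  # AB
```

The blank token is what lets CTC emit the same character twice in a row: frames decoded as `A, A, blank, A` give `"AA"`.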

I hope I understood your message correctly. I'm happy to hop on a call if it helps.

ognjenkundacina commented 4 months ago

Thank you very much, I've managed to implement what I wanted using these instructions!