NVIDIA / audio-flamingo

PyTorch implementation of Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities.
MIT License

Which sentence BERT is used? #7

Closed: jasonppy closed this issue 2 months ago

jasonppy commented 2 months ago

Hi,

Regarding the evaluation on FSD50k, sentence BERT is used to compute the approximate F1 score. The original description from the paper is the following:

Note that we define F1approx to measure inexact but similar predicted labels in FSD50k, where we consider the prediction to be correct if the sentence BERT similarity between output and ground truth is > 0.8

I have 3 questions:

  1. Is the sentence-transformers package used, and if so, which model exactly?
  2. Is the reported score the mean F1 score, i.e., the average of the F1 scores over all QA pairs?
  3. Is the process the following? Given GT: "Electric_guitar,Guitar,Plucked_string_instrument,Musical_instrument,Music" and Pred: "Acoustic_guitar,Guitar,Musical_instrument,Music", we first split them into lists, GT_list: ['Electric_guitar', 'Guitar', 'Plucked_string_instrument', 'Musical_instrument', 'Music'] and Pred_list: ['Acoustic_guitar', 'Guitar', 'Musical_instrument', 'Music'], then use sentence BERT to embed each item and compute pairwise similarities (see the sketch below)?
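
To make question 3 concrete, here is a minimal sketch of what I have in mind; the model name 'all-MiniLM-L6-v2' is only my guess for question 1, and the 0.8 threshold is from the paper:

from sentence_transformers import SentenceTransformer, util

# model choice is my guess; see question 1
model = SentenceTransformer('all-MiniLM-L6-v2')

gt_list = ['Electric_guitar', 'Guitar', 'Plucked_string_instrument', 'Musical_instrument', 'Music']
pred_list = ['Acoustic_guitar', 'Guitar', 'Musical_instrument', 'Music']

gt_emb = model.encode(gt_list, convert_to_tensor=True)
pred_emb = model.encode(pred_list, convert_to_tensor=True)

# pairwise cosine similarities, shape (len(gt_list), len(pred_list))
sim = util.pytorch_cos_sim(gt_emb, pred_emb)
matches = sim > 0.8  # threshold from the paper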

Thanks for your time!

zhifengkongnv commented 2 months ago

You're mostly right. See the evaluation code below for details:

import numpy as np

def eval_MultiEventClassification_similarity(results, verbose=False):
    # https://www.sbert.net/
    from sentence_transformers import SentenceTransformer, util
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    threshold = 0.8

    dic = {
        "f1": [],
        "precision": [],
        "recall": []
    }

    def eval_each(prompt, ground_truth, output):
        # split on ",", strip extra whitespace, lowercase, and replace "-" with a space
        gt_words = [x.strip().lower().replace('-', ' ') for x in ground_truth.split(',')]
        output = [x.strip().lower().replace('-', ' ') for x in output.split(',')]

        if len(gt_words) == 0:
            return 

        if len(output) == 0:
            f1, precision, recall = 0.0, 0.0, 0.0

        else:
            ground_truth_embeddings = embedding_model.encode(gt_words, convert_to_tensor=True)
            prediction_embeddings = embedding_model.encode(output, convert_to_tensor=True)

            similarity_matrix = util.pytorch_cos_sim(ground_truth_embeddings, prediction_embeddings)
            match_matrix = similarity_matrix >= threshold

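            # every (gt, pred) pair above the threshold is counted as one true
            # positive; the remaining predictions and ground-truth labels are
            # counted as false positives and false negatives, respectively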
            TP = match_matrix.sum().item()
            FP = len(output) - TP
            FN = len(gt_words) - TP

            precision = TP / (TP + FP) if TP + FP > 0 else 0
            recall = TP / (TP + FN) if TP + FN > 0 else 0
            f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

        dic["f1"].append(f1)
        dic["precision"].append(precision)
        dic["recall"].append(recall)

    for filename, prompt, ground_truth, output in results:
        eval_each(prompt, ground_truth, output)

    if verbose:
        print("similarity - f1: {:.3f} \pm {:.3f}; precision: {:.3f} \pm {:.3f}; recall: {:.3f}\ pm {:.3f}".format(
            np.mean(dic["f1"]), np.std(dic["f1"]),
            np.mean(dic["precision"]), np.std(dic["precision"]),
            np.mean(dic["recall"]), np.std(dic["recall"])
        ))

    return dic
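
For example (file name and labels are made up), results is a list of (filename, prompt, ground_truth, output) tuples:

results = [
    ("sample.wav",
     "describe this sound in the order from specific to general.",
     "electric guitar,guitar,plucked string instrument,musical instrument,music",
     "acoustic guitar, guitar, musical instrument, music"),
]
scores = eval_MultiEventClassification_similarity(results, verbose=True)
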
jasonppy commented 2 months ago

Thanks! When FSD50K is downloaded from the official Zenodo source, the label format in FSD50K.ground_truth is "Electric_guitar,Guitar,Plucked_string_instrument,Musical_instrument,Music". However, if we prompt the audio-flamingo foundation model with prefix 'The task is event classification.' and prompt 'describe this sound in the order from specific to general.', the output of the model will be "electric guitar, guitar, plucked string instrument, musical instrument, music".

The parsing scheme in the code above will parse the ground truth into ['electric_guitar', 'guitar', 'plucked_string_instrument', 'musical_instrument', 'music'] and the output into ['electric guitar', 'guitar', 'plucked string instrument', 'musical instrument', 'music'], so there is a format mismatch.

This is probably the case, but just to confirm: there is also a preprocessing step that replaces the underscores "_" in the ground truth with spaces " ", right? And this step is also applied to the labels when training audio-flamingo on FSD50k.

Btw, using your code with underscores handled, I am able to get an approximate F1 of 69.6, which matches the number reported in the paper. Thanks!
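
Concretely, the change I made is roughly one line in the parsing, replacing underscores with spaces as well:

gt_words = [x.strip().lower().replace('-', ' ').replace('_', ' ') for x in ground_truth.split(',')]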

zhifengkongnv commented 2 months ago

All underscore symbols (and several other special symbols) are specifically handled during training and inference.