joeweiss / birdnetlib

A python api for BirdNET-Lite and BirdNET-Analyzer
https://joeweiss.github.io/birdnetlib/
Apache License 2.0
41 stars, 14 forks

Use retrained BirdNET-Analyzer model with birdnetlib? #108

MaHaWo closed this issue 7 months ago

MaHaWo commented 8 months ago

I'm trying to figure out how to replace models in birdnetlib, since we need that for a research project. As a first step, I tried to use the result of the BirdNET retraining example via birdnetlib:

from pathlib import Path

from birdnetlib import Recording
from birdnetlib.analyzer import Analyzer

custom_model_path = '/path/to/BirdNET-Analyzer/checkpoints/custom/Custom_Classifier.tflite'
custom_labels_path = '/path/to/BirdNET-Analyzer/checkpoints/custom/Custom_Classifier_Labels.txt'
current = Path('/path/to/BirdNET-Analyzer/example')  # placeholder: folder containing soundscape.wav

recording = Recording(
    Analyzer(classifier_labels_path=custom_labels_path, classifier_model_path=custom_model_path),
    current / Path("soundscape.wav"),
    min_conf=0.25,
)

recording.analyze()

This, however, fails while TensorFlow Lite prepares the custom model (the "failed to prepare" message below), with an error that apparently indicates the input data has the wrong shape:

Traceback (most recent call last):
  File "/respective/path/birdnet_experiments/custom_classifier_birdlib.py", line 112, in <module>
    recording.analyze()
  File "/respective/install/path/lib/python3.11/site-packages/birdnetlib/main.py", line 70, in analyze
    self.analyzer.analyze_recording(self)
  File "/respective/install/path/lib/python3.11/site-packages/birdnetlib/analyzer.py", line 340, in analyze_recording
    pred = self.predict_with_custom_classifier(c)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/respective/install/path/lib/python3.11/site-packages/birdnetlib/analyzer.py", line 447, in predict_with_custom_classifier
    C_INTERPRETER.invoke()
  File "/respective/install/path/lib/python3.11/site-packages/tflite_runtime/interpreter.py", line 941, in invoke
    self._interpreter.Invoke()
RuntimeError: /tensorflow/tensorflow/lite/kernels/concatenation.cc:162 t->dims->data[d] != t0->dims->data[d] (1 != 0)Node number 92 (CONCATENATION) failed to prepare.

The version without the custom model works fine. I then dug around in the code a bit and copied the respective code from the model.py file of BirdNET-Analyzer into a child class of Analyzer:

from pathlib import Path

import numpy as np
from birdnetlib import Recording
from birdnetlib.analyzer import Analyzer

class CustomAnalyzer(Analyzer):
    """
    Builds on the Analyzer class and overrides the 'predict_with_custom_classifier' method with a copy of what BirdNET-Analyzer does (BirdNET-Analyzer/model.py@563).
    Also includes a get_embeddings method that is a copy of the 'embeddings' function in BirdNET-Analyzer (BirdNET-Analyzer/model.py@600).
    """
    def get_embeddings(self, data):
        # This is what the 'embeddings' function in BirdNET-Analyzer does:
        # resize the base model's input tensor to the shape of the batch first.
        self.interpreter.resize_tensor_input(self.input_layer_index, [len(data), *data[0].shape])
        self.interpreter.allocate_tensors()

        # Extract feature embeddings
        self.interpreter.set_tensor(self.input_layer_index, np.array(data, dtype="float32"))
        self.interpreter.invoke()
        features = self.interpreter.get_tensor(self.output_layer_index)

        return features 

    def predict_with_custom_classifier(self, sample): 

        # overrides the predict_with_custom_classifier in the 'Analyzer' class 
        # of birdnetlib with what the BirdNET-Analyzer system does. 
        data = np.array([sample], dtype="float32")

        input_details = self.custom_interpreter.get_input_details()

        output_details = self.custom_interpreter.get_output_details()

        C_INPUT_SIZE = input_details[0]["shape"][-1]

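        # 144000 samples = 3 s of audio at 48 kHz: a classifier with that input size
        # takes raw audio; any other input size means it expects BirdNET embeddings.
        feature_vector = self.get_embeddings(data) if C_INPUT_SIZE != 144000 else data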

        self.custom_interpreter.resize_tensor_input(
            self.custom_input_layer_index, [len(feature_vector), *feature_vector[0].shape]
        )

        self.custom_interpreter.allocate_tensors()

        # Make a prediction
        self.custom_interpreter.set_tensor(
            self.custom_input_layer_index, np.array(feature_vector, dtype="float32")
        )

        self.custom_interpreter.invoke()

        prediction = self.custom_interpreter.get_tensor(self.custom_output_layer_index)

        # Logits or sigmoid activations?
        APPLY_SIGMOID = True 
        if APPLY_SIGMOID:
            SIGMOID_SENSITIVITY = 1.0
            prediction = self.flat_sigmoid(
                np.array(prediction), sensitivity=-SIGMOID_SENSITIVITY
            )

        return prediction

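custom_model_path = '/path/to/BirdNET-Analyzer/checkpoints/custom/Custom_Classifier.tflite'
custom_labels_path = '/path/to/BirdNET-Analyzer/checkpoints/custom/Custom_Classifier_Labels.txt'
current = Path('/path/to/BirdNET-Analyzer/example')  # placeholder: folder containing soundscape.wav

recording = Recording(
    CustomAnalyzer(classifier_labels_path=custom_labels_path, classifier_model_path=custom_model_path),
    current / Path("soundscape.wav"),
    min_conf=0.25,
)

recording.analyze()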

With this, things appear to work fine. I'm new to TensorFlow/machine learning, so I have a hard time interpreting the underlying logic that results in this problem.
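The logic the workaround hinges on is the `C_INPUT_SIZE` check. A classifier trained on raw audio expects one 3-second chunk at 48 kHz (3 × 48,000 = 144,000 samples) per row, whereas a classifier head trained on BirdNET features expects embeddings that the base model has to compute first. One plausible reading of the traceback is that the custom head received data in a shape it was not built for. A minimal sketch of that routing decision follows; `prepare_classifier_input` and `fake_embed` are hypothetical names, `embed` stands in for the base-model call (`get_embeddings` in the override), and the 1024-dimensional embedding size is an assumption for illustration only.

```python
import numpy as np

SAMPLE_RATE = 48_000  # BirdNET's input sample rate in Hz
CHUNK_SECONDS = 3     # BirdNET analyses audio in 3-second chunks
RAW_AUDIO_SIZE = SAMPLE_RATE * CHUNK_SECONDS  # 144000 samples per chunk


def prepare_classifier_input(chunk, classifier_input_size, embed):
    """Build the batch a custom classifier head should receive.

    chunk: 1-D array holding one chunk of RAW_AUDIO_SIZE samples.
    classifier_input_size: last dimension of the custom model's input tensor
        (input_details[0]["shape"][-1] in the override).
    embed: callable mapping a (batch, RAW_AUDIO_SIZE) array to embeddings;
        a stand-in for the base BirdNET model.
    """
    batch = np.array([chunk], dtype="float32")  # add batch dimension -> (1, 144000)
    if classifier_input_size == RAW_AUDIO_SIZE:
        return batch  # head was trained directly on raw audio
    features = np.asarray(embed(batch), dtype="float32")
    if features.shape[-1] != classifier_input_size:
        raise ValueError(
            f"embedding size {features.shape[-1]} does not match "
            f"classifier input size {classifier_input_size}"
        )
    return features


def fake_embed(batch):
    """Hypothetical stand-in for the base model: one 1024-d vector per chunk."""
    return np.zeros((batch.shape[0], 1024), dtype="float32")


chunk = np.zeros(RAW_AUDIO_SIZE, dtype="float32")
print(prepare_classifier_input(chunk, 1024, fake_embed).shape)            # → (1, 1024)
print(prepare_classifier_input(chunk, RAW_AUDIO_SIZE, fake_embed).shape)  # → (1, 144000)
```

Checking the embedding width explicitly would turn a shape mismatch into a readable Python error before anything reaches the interpreter, instead of the TFLite "failed to prepare" message.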

Is this a symptom of support for retrained/custom models still being in its early stages?
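The `APPLY_SIGMOID` step in the CustomAnalyzer override answers its own "Logits or sigmoid activations?" comment: the custom head's output is treated as raw logits, and a sigmoid maps them to scores between 0 and 1 that can be compared against `min_conf`. A minimal sketch of such a function, assuming it behaves like the `flat_sigmoid` used in BirdNET-Analyzer's model.py (clip the logits, then apply a sigmoid scaled by `sensitivity`); the clipping bound of 15 and the example logits are assumptions of this sketch, not taken from the thread.

```python
import numpy as np


def flat_sigmoid(x, sensitivity=-1.0, clip=15.0):
    """Map raw logits to scores in (0, 1).

    With sensitivity=-1.0 this is the standard logistic sigmoid
    1 / (1 + exp(-x)). Clipping the logits keeps np.exp from overflowing;
    the bound of 15 is an assumption of this sketch.
    """
    x = np.asarray(x, dtype="float32")
    return 1.0 / (1.0 + np.exp(sensitivity * np.clip(x, -clip, clip)))


# Made-up logits for three classes of one chunk.
logits = np.array([[-2.0, 0.0, 3.0]], dtype="float32")
scores = flat_sigmoid(logits, sensitivity=-1.0)
print(scores)  # a logit of 0 maps to 0.5; larger logits map closer to 1
```

The override passes `sensitivity=-SIGMOID_SENSITIVITY`, i.e. -1.0, which gives the plain sigmoid; a larger `SIGMOID_SENSITIVITY` would make the scores steeper around a logit of 0. Without this step, raw logits (which can be negative or larger than 1) would presumably be compared directly against `min_conf=0.25`.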

MaHaWo commented 8 months ago

possibly relevant: #80, #67

joeweiss commented 8 months ago

Hi @MaHaWo, thanks for digging into this, and for posting your working code. I'm away from the lab until Wednesday, so I can't test your workaround in depth (other than on my field laptop).

Are you seeing the same detection results from your modified code as with BirdNET's command-line script, python3 analyze.py --classifier checkpoints/custom/Custom_Classifier.tflite?

If I recall correctly, the custom models I'm using are trained using BirdNET 2.3, so perhaps I've missed a recent change from Stefan and his team.

If you don't mind sharing your model (or any other failing model), here's an upload link: https://www.dropbox.com/request/lTUadB0XVNNJ9HMmny89

I only use provided models internally, and delete them after closing this issue. If you'd rather not upload, that's ok too.

MaHaWo commented 8 months ago

Hello @joeweiss, thanks for the quick reply!

import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
from birdnetlib import Recording

# CustomAnalyzer is the subclass defined in my first comment.

custom_species_list = '/path/to/where/stuff/is/BirdNET-Analyzer/example/species_list.txt'

custom_model_path = '/path/to/where/stuff/is/BirdNET-Analyzer/checkpoints/custom/Custom_Classifier.tflite'
custom_labels_path = '/path/to/where/stuff/is/BirdNET-Analyzer/checkpoints/custom/Custom_Classifier_Labels.txt'
current = Path('/path/to/where/stuff/is/BirdNET-Analyzer/example')  # folder containing soundscape.wav

# make classification with the custom analyzer/model
# Load and initialize the BirdNET-Analyzer models.
recording = Recording(
    CustomAnalyzer(classifier_labels_path=custom_labels_path, classifier_model_path=custom_model_path),
    current / Path("soundscape.wav"),
    min_conf=0.25,
)
recording.analyze()

pd.DataFrame(recording.detections).sort_values(by="confidence", ascending=True).to_csv(
    "/path/to/where/stuff/is/birdnet_experiments/example_results/custom_results_birdnetlib_experimental.csv"
)

# make classification with BirdNET-Analyzer itself
command = [
    "python3", "/path/to/where/stuff/is/BirdNET-Analyzer/analyze.py",
    "--classifier", custom_model_path,
    "--i", "/path/to/where/stuff/is/BirdNET-Analyzer/example",
    "--o", "/path/to/where/stuff/is/BirdNET-Analyzer/example_custom_results",
    "--min_conf", "0.25",
    "--threads", "12",
    "--rtype", "csv",
]
subprocess.run(command, check=True)

# check results, put into dataframes for prettiness
birdnet_analyzer = pd.read_csv('/path/to/where/stuff/is/BirdNET-Analyzer/example_custom_results/soundscape.BirdNET.results.csv').sort_values(by='Confidence', ascending=True)

birdnet_lib = pd.read_csv('/path/to/where/stuff/is/birdnet_experiments/example_results/custom_results_birdnetlib_experimental.csv').sort_values(by='confidence', ascending=True)
birdnet_lib['confidence'] = birdnet_lib['confidence'].round(4)

# print to visually inspect
analyzer_cols = ['Start (s)', 'End (s)', 'Confidence', 'Common name']
lib_cols = ['start_time', 'end_time', 'confidence', 'common_name']

print("analyzer:", birdnet_analyzer[analyzer_cols])
print("birdnet-lib:", birdnet_lib[lib_cols])
print("equal?", np.all(birdnet_analyzer[analyzer_cols].to_numpy() == birdnet_lib[lib_cols].to_numpy()))


I omit the output here for brevity, but for me the equality check prints 'True', so the results appear to be identical.
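The equality check in the script requires the confidences to match exactly after rounding to four decimals and both tables to end up in the same row order after sorting. A sketch of a slightly more forgiving comparison, pairing detections by time window and species and then comparing confidences within a tolerance; the function name `detections_match`, the tolerance of 1e-4 and the example rows are made up for illustration, while the column names are the ones used in the script above.

```python
import numpy as np
import pandas as pd

# BirdNET-Analyzer CSV column -> birdnetlib detection key.
ANALYZER_TO_LIB = {
    "Start (s)": "start_time",
    "End (s)": "end_time",
    "Confidence": "confidence",
    "Common name": "common_name",
}
KEYS = ["start_time", "end_time", "common_name"]


def detections_match(analyzer_df, lib_df, atol=1e-4):
    """Return True if both tables hold the same detections.

    Rows are paired on (start_time, end_time, common_name), so row order does
    not matter; paired confidences may differ by about `atol` (np.allclose
    also adds its default relative tolerance).
    """
    cols = list(ANALYZER_TO_LIB.values())
    left = analyzer_df.rename(columns=ANALYZER_TO_LIB)[cols]
    right = lib_df[cols]
    if len(left) != len(right):
        return False
    merged = left.merge(right, on=KEYS, how="outer",
                        suffixes=("_analyzer", "_lib"), indicator=True)
    if (merged["_merge"] != "both").any():
        return False  # a detection appears in only one of the tables
    return bool(np.allclose(merged["confidence_analyzer"],
                            merged["confidence_lib"], atol=atol))


# Made-up example rows; the birdnetlib rows are in a different order.
a = pd.DataFrame({"Start (s)": [0.0, 3.0], "End (s)": [3.0, 6.0],
                  "Confidence": [0.5012, 0.9],
                  "Common name": ["Species A", "Species B"]})
b = pd.DataFrame({"start_time": [3.0, 0.0], "end_time": [6.0, 3.0],
                  "confidence": [0.90003, 0.50118],
                  "common_name": ["Species B", "Species A"]})
print(detections_match(a, b))  # → True
```

Pairing on the time window and species means a tie in confidence, which could swap rows after sorting, no longer makes otherwise identical results look different.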