dreamquark-ai / tabnet

PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
https://dreamquark-ai.github.io/tabnet/
MIT License

Multiclass custom metric weighted f1 score #400

Closed: AbrahamBlauvelt closed this issue 2 years ago

AbrahamBlauvelt commented 2 years ago

Hello, I am trying to compare multiple classifiers using the weighted F1 score for a multiclass problem (4 classes), so I tried to implement the weighted F1 score as a custom metric:

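```python
from pytorch_tabnet.metrics import Metric
from sklearn.metrics import f1_score


class my_metric(Metric):

    def __init__(self):
        self._name = "f1"  # write an understandable name here
        self._maximize = True  # a higher F1 score is better

    def __call__(self, y_true, y_score):
        # y_score holds per-class probabilities here, not hard labels
        return f1_score(y_true, y_score, average='weighted')
```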

However, I get this error message:

`Classification metrics can't handle a mix of multiclass and continuous-multioutput targets`

I know I'm doing something wrong with my f1_score call, but I don't know exactly what. It's probably very obvious, but I can't figure it out. Thanks in advance for any reply.

Optimox commented 2 years ago

This comes from sklearn's f1_score function, and I guess it's about how you handle the predictions: during training, the model outputs class probabilities, not hard predictions.

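So you might want to add `y_score = np.argmax(y_score, axis=1)` before feeding it to the f1_score function. (I always get the axis wrong, but since y_score has one row per sample and one column per class, axis=1 is the one that picks the predicted class for each sample.)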

To figure out what to do, print the shapes of the inputs, then call f1_score outside of training with fake inputs of the same shapes and see when it works.
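Following that advice, here is a minimal sketch with fake inputs (8 samples, 4 classes, random probabilities, all made up for illustration). It prints the shapes, reproduces the error by passing probabilities straight to f1_score, and then fixes it with `np.argmax(..., axis=1)`. It uses only NumPy and scikit-learn, not TabNet itself.

```python
import numpy as np
from sklearn.metrics import f1_score

# Fake inputs shaped like what a 4-class metric receives during training:
# y_true holds integer labels, y_score holds one probability per class.
rng = np.random.default_rng(0)
n_samples, n_classes = 8, 4
y_true = rng.integers(0, n_classes, size=n_samples)
y_score = rng.random((n_samples, n_classes))
y_score /= y_score.sum(axis=1, keepdims=True)  # each row now sums to 1

print(y_true.shape, y_score.shape)  # (8,) (8, 4)

# Passing probabilities directly reproduces the error from the issue.
try:
    f1_score(y_true, y_score, average="weighted")
    probabilities_accepted = True
except ValueError as err:
    probabilities_accepted = False
    print("fails:", err)

# argmax over axis=1 picks the most probable class in each row (sample),
# giving hard labels with the same shape as y_true.
y_pred = np.argmax(y_score, axis=1)
print(y_pred.shape)  # (8,)

score = f1_score(y_true, y_pred, average="weighted")
print("weighted F1:", score)
```

With axis=0 you would instead get one index per class column (shape `(4,)`), which no longer lines up with y_true.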

AbrahamBlauvelt commented 2 years ago

Thank you for your reply, I got it to work! Even after applying your fix, my metric still failed: I had named it "f1", but the old definition of the metric was still being used instead of the new one. After some experimenting, I renamed it to "f3" and now it works. I guess redefining a custom metric you have already created doesn't actually overwrite it (or I did something else wrong, which is very likely). Anyway, here is the working snippet:

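```python
import numpy as np
from pytorch_tabnet.metrics import Metric
from sklearn.metrics import f1_score


class my_metric(Metric):

    def __init__(self):
        self._name = "f3"
        self._maximize = True

    def __call__(self, y_true, y_score):
        # Convert per-class probabilities to hard labels before scoring
        return f1_score(y_true, np.argmax(y_score, axis=1), average='weighted')
```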

Thank you for your help and have a nice weekend!

Optimox commented 2 years ago

@AbrahamBlauvelt good to hear!

You can't overwrite an existing metric class just by redefining it, but you can definitely name your new metric f1: you probably just need to restart your notebook kernel!
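One plausible mechanism for the "old definition" behaviour, stated as an assumption rather than a description of TabNet's actual code: if metric names are resolved by scanning the live subclasses of a `Metric` base class and taking the first one whose `_name` matches, then an older `f1` class that is still alive (for example, because an earlier training run holds an instance of it) shadows the redefinition. The pure-Python sketch below reproduces that pattern with a stand-in `Metric` class and a hypothetical `get_by_name` lookup; neither is pytorch_tabnet's real API.

```python
# Stand-in base class for illustration only; NOT pytorch_tabnet's Metric.
class Metric:
    @classmethod
    def get_by_name(cls, name):
        # Hypothetical lookup: the first live subclass whose _name matches wins.
        for subclass in cls.__subclasses__():
            if subclass()._name == name:
                return subclass()
        raise KeyError(name)


# First attempt (the buggy version), defined in a notebook cell.
class my_metric(Metric):
    def __init__(self):
        self._name = "f1"

    def __call__(self, y_true, y_score):
        return "old definition"


# An earlier training run keeps an instance alive, and that instance
# keeps the old class object alive too.
kept_by_earlier_run = my_metric()


# Re-running the cell with fixed code creates a *new* class object
# that merely reuses the Python name my_metric.
class my_metric(Metric):
    def __init__(self):
        self._name = "f1"

    def __call__(self, y_true, y_score):
        return "new definition"


print(len(Metric.__subclasses__()))          # 2: both classes are still alive
print(Metric.get_by_name("f1")(None, None))  # old definition
```

Under this assumption, renaming the metric to "f3" works because no stale class carries that name, and restarting the kernel works because it discards the stale class entirely.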