WorldCereal / presto-worldcereal


Improve class weight computation + CatBoost improvements #46

Closed: gabrieltseng closed this 4 months ago

gabrieltseng commented 5 months ago

TL;DR

Variations in performance of the finetuned model don't seem to trickle down to the sklearn models, so any of the options would be okay. Balanced sampling seems fine to go with for now, so that is what the PR currently uses (and it is very marginally better).

Class weights

Results for https://github.com/WorldCereal/presto-worldcereal/pull/46/commits/cb55f93554ec30d1700754d5ecc142ff29c036bb below (wandb run):

Interestingly, fixing the finetuning weights :bug: seems to lead to much worse finetuning results for the maize model (F1: 0.6711 vs. 0.8081 before). This also seems to impact the sklearn models trained on top of this model (except for the Random Forest and CatBoost models).

| Task | Model | Head | F1 | Recall | Precision |
| --- | --- | --- | --- | --- | --- |
| Crop vs. non crop | WorldCereal | CatBoost | 0.8489 | 0.8518 | 0.8462 |
| Crop vs. non crop | Full, Crop vs. non crop | Finetuned | 0.8572 | 0.9070 | 0.8126 |
| Crop vs. non crop | Full, Crop vs. non crop | Random Forest | 0.8665 | 0.8425 | 0.8919 |
| Crop vs. non crop | Full, Crop vs. non crop | Logistic Regression | 0.8591 | 0.9113 | 0.8125 |
| Crop vs. non crop | Full, Crop vs. non crop | CatBoost | 0.8626 | 0.9079 | 0.8216 |
| Maize | Full, Maize | Finetuned | 0.6711 | 0.9266 | 0.5261 |
| Maize | Full, Maize | Random Forest | 0.8000 | 0.7360 | 0.8762 |
| Maize | Full, Maize | Logistic Regression | 0.7049 | 0.9176 | 0.5723 |
| Maize | Full, Maize | CatBoost | 0.7589 | 0.8901 | 0.6614 |
| Maize | Full, Crop vs. Non Crop | Random Forest | 0.6712 | 0.5453 | 0.8725 |
| Maize | Full, Crop vs. Non Crop | Logistic Regression | 0.5462 | 0.8855 | 0.3949 |
| Maize | Full, Crop vs. Non Crop | CatBoost | 0.6940 | 0.8739 | 0.5755 |

For maize, the imbalance (and therefore the pos_weight) is very high:

```python
>>> weights = torch.from_numpy(train_ds.class_weights)
>>> weights
tensor([0.5490, 5.6046], dtype=torch.float64)
>>> (weights / weights[0])[1]
tensor(10.2093, dtype=torch.float64)
```
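
For context, a minimal sketch of how such balanced class weights typically feed into a binary finetuning loss (assuming a `BCEWithLogitsLoss`-style setup; the sample counts below are illustrative, only the resulting weight ratio mirrors the snippet above):

```python
import numpy as np
import torch
from torch import nn

# Illustrative binary labels with a ~9% positive rate (roughly the maize imbalance).
labels = np.array([0] * 911 + [1] * 89)

# "Balanced" class weights: n_samples / (n_classes * class_count), so the
# weight of the positive class grows as that class gets rarer.
counts = np.bincount(labels)
class_weights = labels.size / (2 * counts)            # ~[0.55, 5.62]

# For a binary BCE loss the two weights collapse into a single pos_weight,
# which is the ~10x ratio shown in the snippet above.
pos_weight = torch.tensor(class_weights[1] / class_weights[0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```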

Having such a high weight seems bad to me? I think three solutions would be:

Clamped class weights

Results from the experiment with clamped class weights at finetuning time (a sketch of the clamping idea follows the table). This gives better (maize) finetuning results, but not a huge difference for the sklearn models.

| Task | Model | Head | F1 | Recall | Precision |
| --- | --- | --- | --- | --- | --- |
| Crop vs. non crop | WorldCereal | CatBoost | 0.8489 | 0.8518 | 0.8462 |
| Crop vs. non crop | Full, Crop vs. non crop | Finetuned | 0.8601 | 0.8832 | 0.8381 |
| Crop vs. non crop | Full, Crop vs. non crop | Random Forest | 0.8624 | 0.8390 | 0.8873 |
| Crop vs. non crop | Full, Crop vs. non crop | Logistic Regression | 0.8540 | 0.9054 | 0.8081 |
| Crop vs. non crop | Full, Crop vs. non crop | CatBoost | 0.8590 | 0.90742 | 0.8181 |
| Maize | Full, Maize | Finetuned | 0.8059 | 0.8330 | 0.7806 |
| Maize | Full, Maize | Random Forest | 0.8053 | 0.7397 | 0.8836 |
| Maize | Full, Maize | Logistic Regression | 0.7168 | 0.9187 | 0.5876 |
| Maize | Full, Maize | CatBoost | 0.7712 | 0.8897 | 0.6806 |
| Maize | Full, Crop vs. Non Crop | Random Forest | 0.6675 | 0.5419 | 0.8690 |
| Maize | Full, Crop vs. Non Crop | Logistic Regression | 0.5410 | 0.8803 | 0.3905 |
| Maize | Full, Crop vs. Non Crop | CatBoost | 0.6822 | 0.8688 | 0.5616 |
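
A minimal sketch of the clamping idea (the cap value below is hypothetical; the PR does not state which threshold was actually used):

```python
import torch

# Class weights from the maize example above.
weights = torch.tensor([0.5490, 5.6046])

# Cap the weights so a heavily imbalanced task (like maize) cannot blow up
# the loss; MAX_WEIGHT is a hypothetical value, not the one used in the PR.
MAX_WEIGHT = 2.0
clamped = weights.clamp(max=MAX_WEIGHT)
pos_weight = (clamped / clamped[0])[1]   # ~3.6 instead of ~10.2
```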

Balanced sampling

Results from the experiment with balanced sampling at finetuning time (a sketch of batch-balanced sampling follows the table).

| Task | Model | Head | F1 | Recall | Precision |
| --- | --- | --- | --- | --- | --- |
| Crop vs. non crop | WorldCereal | CatBoost | 0.8489 | 0.8518 | 0.8462 |
| Crop vs. non crop | Full, Crop vs. non crop | Finetuned | 0.8652 | 0.8841 | 0.8471 |
| Crop vs. non crop | Full, Crop vs. non crop | Random Forest | 0.8661 | 0.8424 | 0.8910 |
| Crop vs. non crop | Full, Crop vs. non crop | Logistic Regression | 0.8581 | 0.9098 | 0.8119 |
| Crop vs. non crop | Full, Crop vs. non crop | CatBoost | 0.8631 | 0.9076 | 0.8228 |
| Maize | Full, Maize | Finetuned | 0.7460 | 0.9085 | 0.6328 |
| Maize | Full, Maize | Random Forest | 0.8068 | 0.7455 | 0.8791 |
| Maize | Full, Maize | Logistic Regression | 0.7276 | 0.9207 | 0.6015 |
| Maize | Full, Maize | CatBoost | 0.7717 | 0.8912 | 0.6805 |
| Maize | Full, Crop vs. Non Crop | Random Forest | 0.6687 | 0.5416 | 0.8690 |
| Maize | Full, Crop vs. Non Crop | Logistic Regression | 0.5457 | 0.8815 | 0.3952 |
| Maize | Full, Crop vs. Non Crop | CatBoost | 0.6879 | 0.8737 | 0.5673 |
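
A minimal sketch of batch-balanced sampling with PyTorch's `WeightedRandomSampler` (dataset and variable names are illustrative, not the repository's actual code):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Illustrative binary labels with a ~9% positive rate.
labels = np.array([0] * 911 + [1] * 89)
dataset = TensorDataset(torch.randn(len(labels), 16), torch.from_numpy(labels))

# Draw each sample with probability inversely proportional to its class
# frequency, so batches are roughly 50/50 and no pos_weight is needed in
# the loss.
sample_weights = 1.0 / np.bincount(labels)[labels]
sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights),
    num_samples=len(labels),
    replacement=True,
)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)
```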
kvantricht commented 5 months ago

> balancing the samples in the batch, removing the need for a class weight entirely without messing about with the loss's magnitude: https://github.com/WorldCereal/presto-worldcereal/commit/3cf5e9531bb0efab85eb97240b64e1ca533ee598

This looks really promising IMO. It's how we trained crop type models for the operational European crop mapping project. I honestly didn't know the balancing was also done when finetuning Presto itself; I thought it was only relevant for the downstream classifiers like CatBoost. In the end, I think we don't want a separate Presto encoder for each crop type, but rather a general finetuned encoder for crop type mapping (and maybe we already learned enough from crop/no-crop?). Could we also try taking the finetuned crop/no-crop Presto and balancing only in the sklearn and CatBoost models? (Or in a head, but then via batch balancing.)
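
If the balancing is instead pushed down to the downstream heads, as suggested here, a minimal sketch using the libraries' built-in options could look like this (a sketch of the idea, not the repository's actual training code):

```python
from catboost import CatBoostClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

# Re-weight classes inside the downstream classifiers instead of during
# Presto finetuning, using each library's built-in balancing option.
rf = RandomForestClassifier(class_weight="balanced")
lr = LogisticRegression(class_weight="balanced", max_iter=1000)
cb = CatBoostClassifier(auto_class_weights="Balanced", verbose=0)

# Each head is then fit on the (frozen or finetuned) Presto encoder's
# embeddings, e.g. rf.fit(embeddings, labels).
```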

gabrieltseng commented 4 months ago

I think there is still work to do with respect to the CatBoost parameters, but I am going to merge this in for now since the class balancing seems to work.