related-sciences / nxontology-ml

Machine learning to classify ontology nodes
Apache License 2.0

Compare few-shot GPT4 features to embedding features for EFO term precision classification #8

Closed (eric-czech closed this issue 11 months ago)

eric-czech commented 1 year ago

@yonromai suggested embedding the text descriptions, labels, aliases, etc. associated with EFO terms and using those embeddings as features as part of https://github.com/related-sciences/nxontology-ml/issues/2.
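
A minimal sketch of what "embedding the text associated with a term" could look like in practice: the text fields are concatenated into one string per term before being passed to an embedding model. The dictionary keys (`label`, `synonyms`, `description`) and the `term_to_text` helper are hypothetical, not the project's actual schema or code.

```python
def term_to_text(term: dict) -> str:
    """Concatenate a term's label, synonyms and description into one string to embed.

    The keys used here are hypothetical; adapt them to the real ontology fields.
    """
    parts = []
    label = term.get("label")
    if label:
        parts.append(label)
    synonyms = term.get("synonyms") or []
    if synonyms:
        parts.append("Synonyms: " + "; ".join(synonyms))
    description = term.get("description")
    if description:
        parts.append(description)
    return ". ".join(parts)


example = {
    "label": "asthma",
    "synonyms": ["asthmatic disorder"],
    "description": "A chronic inflammatory disease of the airways.",
}
print(term_to_text(example))
```

Keeping the label first means short terms with no description still produce a usable, non-empty input.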

It would be very interesting to see how much a model like the one in https://github.com/related-sciences/nxontology-ml/pull/7 improves with embedding features compared to a model with only the few-shot labels from https://github.com/related-sciences/nxontology-ml/pull/6.

The LLM-derived features will definitely be harder to maintain and generate. On the other hand, I know the labels we provided in https://github.com/related-sciences/nxontology-ml/pull/5 are not perfect, and I expect the few-shot features will be more helpful in figuring out which ones are most likely to be mislabeled and why (since they can be compared directly). Nevertheless, contrasting the predictive value of the two could be an important factor in deciding how this project, or at least #2, evolves.
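
Because the few-shot labels and the curated labels use the same four classes, they can be compared term by term. A hypothetical sketch (the `label_disagreements` helper and the term IDs are illustrative, not project code) that lists the terms where the two disagree and counts which class pairs are confused most often, as a starting point for reviewing possible mislabels:

```python
from collections import Counter


def label_disagreements(curated: dict, few_shot: dict):
    """Compare curated labels with few-shot (LLM) labels keyed by term ID.

    Returns a sorted list of (term_id, curated_label, few_shot_label) rows where
    the labels differ, plus a Counter of (curated, few_shot) confusion pairs.
    Terms missing from `few_shot` are skipped.
    """
    rows = []
    for term_id, curated_label in sorted(curated.items()):
        llm_label = few_shot.get(term_id)
        if llm_label is not None and llm_label != curated_label:
            rows.append((term_id, curated_label, llm_label))
    pairs = Counter((curated_label, llm_label) for _, curated_label, llm_label in rows)
    return rows, pairs
```

Sorting the confusion pairs by count points to the class boundaries (for example root vs. area) where the curated labels are most worth re-checking.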

eric-czech commented 1 year ago

FYI @yonromai, this was the embedding model I had in mind when we last spoke: https://huggingface.co/michiyasunaga/BioLinkBERT-base.

That's from a top-tier group in the NLP space, and it's the model released by the first author (michiyasunaga) of LinkBERT: Pretraining Language Models with Document Links (Mar. 2022). The reported improvements over a recent SOTA model (PubMedBERT) are substantial, so it might be worth kicking the tires on it.
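
BioLinkBERT is a BERT-style encoder, so it returns one vector per token; getting one vector per EFO term requires pooling. The thread doesn't say which pooling was used. The sketch below assumes mean pooling over non-padding tokens (a common choice; using the CLS token is another), applied to arrays you would get from the model's last hidden state and the tokenizer's attention mask. The `mean_pool` name is illustrative.

```python
import numpy as np


def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sequence, ignoring padding positions.

    token_embeddings: (batch, seq_len, hidden), e.g. an encoder's last hidden state
    attention_mask:   (batch, seq_len), 1 for real tokens and 0 for padding
    Returns an array of shape (batch, hidden).
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                  # avoid division by zero
    return summed / counts
```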

And to be clear, I have no allegiances to this over a LLaMA-derived model, OpenAI or some KG-based approach. Any performance baseline using embeddings would be helpful.

yonromai commented 1 year ago

cc: @eric-czech @dhimmel

TL;DR:

Idea behind the new features:

Outcome:

Comments:

With embedding features

Indexing Vectors: 100%|██████████| 17331/17331 [15:50<00:00, 18.40it/s]
Learning rate set to 0.091517
0:  learn: 1.2545993    total: 87.9ms   remaining: 1m 27s
100:    learn: 0.3817487    total: 1.74s    remaining: 15.4s
200:    learn: 0.3448022    total: 3.38s    remaining: 13.5s
300:    learn: 0.3226222    total: 5.06s    remaining: 11.7s
400:    learn: 0.3037497    total: 6.74s    remaining: 10.1s
500:    learn: 0.2870626    total: 8.47s    remaining: 8.43s
600:    learn: 0.2722362    total: 10.2s    remaining: 6.75s
700:    learn: 0.2571492    total: 11.9s    remaining: 5.07s
800:    learn: 0.2440707    total: 13.6s    remaining: 3.37s
900:    learn: 0.2329543    total: 15.3s    remaining: 1.68s
999:    learn: 0.2212473    total: 17s  remaining: 0us
> Feature importance:
                     Feature Id  Importances
0                        prefix    13.154424
1                    L4-support     8.085884
2                       n_roots     5.246222
3                    L2-support     4.897259
4          intrinsic_ic_sanchez     4.789902
5                    L1-support     4.645040
6                         depth     4.092159
7                   n_ancestors     3.673886
8   intrinsic_ic_sanchez_scaled     3.632503
9                        L2-min     3.356392
10                       L1-min     3.200558
11           xref__mondo__count     2.781731
12                   L3-support     2.633530
13                n_descendants     2.073415
14            xref__omim__count     1.899459
15                 intrinsic_ic     1.737778
16                       L3-min     1.673483
17                       L4-min     1.573646
18        xref__orphanet__count     1.524709
19            xref__ncit__count     1.478501
20          intrinsic_ic_scaled     1.456620
21                        L2-q1     1.455818
22                    n_parents     1.408397
23            xref__gard__count     1.368191
24            xref__doid__count     1.347312
25            xref__mesh__count     1.225176
26            xref__umls__count     1.107692
27          xref__omimps__count     1.101323
28                        L1-q1     0.880555
29            xref__icd9__count     0.803441
30                       L3-max     0.754765
31           xref__icd10__count     0.745556
32                       L1-med     0.744894
33                       L4-max     0.730284
34                   n_children     0.698697
35          xref__meddra__count     0.666259
36                       L2-max     0.643364
37        xref__snomedct__count     0.634071
38                        L3-q1     0.606200
39                       L3-med     0.597636
40                       L2-med     0.588662
41                        L4-q3     0.583918
42                is_gwas_trait     0.558833
43                        L2-q3     0.534196
44                        L3-q3     0.524524
45                        L4-q1     0.516574
46                       L1-max     0.450720
47                        L1-q3     0.402829
48                     n_leaves     0.377648
49                       L4-med     0.335366

> Classification report:
                    precision    recall  f1-score   support

01-disease-subtype       0.79      0.84      0.81      1085
   02-disease-root       0.70      0.66      0.68       797
   03-disease-area       0.79      0.72      0.76       257
    04-non-disease       0.98      0.98      0.98       920

          accuracy                           0.82      3059
         macro avg       0.82      0.80      0.81      3059
      weighted avg       0.82      0.82      0.82      3059
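
The `macro avg` row is the unweighted mean of the four per-class scores, while `weighted avg` weights each class by its support. A small check recomputing both F1 averages from the rounded per-class values in the report above (so agreement is only up to rounding):

```python
# Per-class F1 scores and supports copied from the "with embedding features" report.
PER_CLASS = {
    "01-disease-subtype": (0.81, 1085),
    "02-disease-root": (0.68, 797),
    "03-disease-area": (0.76, 257),
    "04-non-disease": (0.98, 920),
}


def macro_and_weighted(per_class):
    """Return (macro average, support-weighted average) of per-class scores."""
    scores = [score for score, _ in per_class.values()]
    total_support = sum(support for _, support in per_class.values())
    macro = sum(scores) / len(scores)  # every class counts equally
    weighted = sum(score * support for score, support in per_class.values()) / total_support
    return macro, weighted


macro_f1, weighted_f1 = macro_and_weighted(PER_CLASS)
print(f"macro F1 ~ {macro_f1:.4f}, weighted F1 ~ {weighted_f1:.4f}")
```

Both round to the 0.81 and 0.82 shown in the report. Because the macro average gives 03-disease-area (257 examples) the same weight as 01-disease-subtype (1085), it is the more sensitive of the two to errors on small classes.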

Without embedding features

Learning rate set to 0.091517
0:  learn: 1.2510844    total: 71.5ms   remaining: 1m 11s
100:    learn: 0.4182990    total: 1.06s    remaining: 9.43s
200:    learn: 0.3891503    total: 2.07s    remaining: 8.24s
300:    learn: 0.3713912    total: 3.05s    remaining: 7.09s
400:    learn: 0.3581384    total: 4.03s    remaining: 6.03s
500:    learn: 0.3451430    total: 5.01s    remaining: 4.99s
600:    learn: 0.3334726    total: 5.98s    remaining: 3.97s
700:    learn: 0.3224103    total: 7s   remaining: 2.98s
800:    learn: 0.3125697    total: 8.03s    remaining: 2s
900:    learn: 0.3042341    total: 9.04s    remaining: 993ms
999:    learn: 0.2964110    total: 10s  remaining: 0us
> Feature importance:
                     Feature Id  Importances
0                        prefix    19.235261
1                       n_roots     7.848109
2                   n_ancestors     7.547822
3                         depth     7.365341
4          intrinsic_ic_sanchez     5.770118
5            xref__mondo__count     4.362500
6             xref__omim__count     3.867349
7             xref__doid__count     3.414090
8         xref__orphanet__count     3.360891
9   intrinsic_ic_sanchez_scaled     3.086075
10            xref__mesh__count     3.016446
11                    n_parents     2.935589
12                is_gwas_trait     2.854420
13            xref__umls__count     2.633490
14            xref__ncit__count     2.608488
15                n_descendants     2.435195
16            xref__gard__count     2.336128
17                   n_children     2.183118
18                 intrinsic_ic     2.152569
19          xref__meddra__count     2.075885
20        xref__snomedct__count     1.854948
21           xref__icd10__count     1.849050
22          intrinsic_ic_scaled     1.791724
23            xref__icd9__count     1.153552
24                     n_leaves     1.142829
25          xref__omimps__count     1.119010

> Classification report:
                    precision    recall  f1-score   support

01-disease-subtype       0.80      0.85      0.82      1195
   02-disease-root       0.69      0.61      0.65       719
   03-disease-area       0.77      0.73      0.75       244
    04-non-disease       0.96      0.99      0.97       901

          accuracy                           0.82      3059
         macro avg       0.81      0.79      0.80      3059
      weighted avg       0.82      0.82      0.82      3059

eric-czech commented 1 year ago

Nice @yonromai!

More analysis needs to take place to understand whether there is value in these new features (and whether the value added is worth the extra complexity)

I think this is clear above, but that does not include GPT-4 assignments of the labels as features (i.e. from https://github.com/related-sciences/nxontology-ml/pull/6) correct?

I assumed that the embedding size would be too large to be exploited "as is" by the model

I think it would be ok to include the embeddings or a reduction on them (e.g. PCA) as features directly. I like the tree/clustering approach, but my hunch is that it will be hard to show an improvement over that simpler method.
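
A minimal sketch of the simpler option: reduce the embedding matrix with PCA and use the projected coordinates directly as features. It uses a plain NumPy SVD to show what the reduction does; in practice `sklearn.decomposition.PCA(n_components=k)` does the same job (component signs may differ). The function name and the number of components are illustrative choices.

```python
import numpy as np


def pca_reduce(embeddings: np.ndarray, n_components: int) -> np.ndarray:
    """Project row-wise embeddings (n_terms, dim) onto their top principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T  # (n_terms, n_components)


rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(200, 32))  # stand-in for real text embeddings
features = pca_reduce(fake_embeddings, n_components=8)
print(features.shape)
```

Each output column can then be appended to the tabular features as an ordinary numeric column.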

Experiment results

Do you have a sense of how much macro F1 averages vary across resamplings (e.g. with cross_val_score(..., cv=5, scoring='f1_macro'))? It would be helpful to know what kind of performance loss in an ablation experiment like that should rate as substantial.
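
For reference, a sketch of the resampling check being asked about. Synthetic data and `LogisticRegression` stand in for the EFO features and the CatBoost model (both assumptions); a shuffled `StratifiedKFold` is passed explicitly so the folds are reproducible.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic 4-class data standing in for the real feature matrix (assumption).
X, y = make_classification(n_samples=600, n_features=20, n_informative=8,
                           n_classes=4, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                         cv=cv, scoring="f1_macro")
# Sample standard deviation across folds as a rough measure of resampling noise.
print(f"macro F1: {scores.mean():.3f} +/- {scores.std(ddof=1):.3f} over {len(scores)} folds")
```

The fold-to-fold spread gives a rough sense of how large a drop in macro F1 has to be before it stands out from resampling noise.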

yonromai commented 1 year ago

I think this is clear above, but that does not include GPT-4 assignments of the labels as features (i.e. from https://github.com/related-sciences/nxontology-ml/pull/6) correct?

That's right, I think we should do that next.

I think it would be ok to include the embeddings or a reduction on them (e.g. PCA) as features directly. I like the tree/clustering approach, but my hunch is that it will be hard to show an improvement over that simpler method.

Sure, I'll give it a try!

Do you have a sense of how much macro F1 averages vary across resamplings (e.g. with cross_val_score(..., cv=5, scoring='f1_macro'))? It would be helpful to know what kind of performance loss in an ablation experiment like that should rate as substantial.

That's a great question! I just started using the ROC AUC & MAE metrics from #9 to look into the performance of the model. I'll spend a little time in notebook land looking at how features & model parameters influence metrics.

eric-czech commented 1 year ago

looking at how features & model parameters influence metrics

Awesome, sounds good! Just so we're clear, I'm proposing that we compare distributions of F1, ROC AUC, MAE, etc. scores between models, where each distribution comes from evaluating those metrics on multiple folds. Given that this dataset is small, I think we'll need that to understand which changes are significant. Would you agree?

yonromai commented 1 year ago

Yes, totally agree. The idea is to "repeat the experiment" of training the same model on different (stratified) folds of the training set to get an idea of the spread of the metrics. We can then use this spread to judge whether metric changes are significant once we change the model/features. Is that what you meant?
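
When both models are scored on the same folds, the per-fold differences can be compared directly. A hypothetical stdlib-only sketch (`paired_fold_comparison` is not project code) returning the mean difference, its spread, and a naive paired t statistic:

```python
import math
import statistics


def paired_fold_comparison(scores_a, scores_b):
    """Compare two models evaluated on the *same* CV folds.

    Returns (mean per-fold difference a - b, sample std of the differences,
    naive paired t statistic).
    """
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need matching fold scores for at least two folds")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)
    if sd_diff == 0:
        # Every fold moved by exactly the same amount; the t statistic is undefined.
        return mean_diff, sd_diff, float("nan")
    t_stat = mean_diff / (sd_diff / math.sqrt(len(diffs)))
    return mean_diff, sd_diff, t_stat
```

The naive t statistic tends to overstate significance because CV folds share most of their training data; the corrected resampled t-test of Nadeau and Bengio is a common adjustment. Treat the number as a rough guide rather than a formal test.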

eric-czech commented 1 year ago

Is that what you meant?

Indeed 👍

yonromai commented 1 year ago

Okay, so after some time in notebook land, here is the gist of what I found:

TL;DR

@eric-czech both of your hunches were 💯 :

@dhimmel Implementing the MAE (with the class biases suggested by @eric-czech) has proven very useful!
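
The thread doesn't define this MAE or the class biases, so the following is only a hypothetical sketch of one way an MAE can be applied to these labels: treat the four classes as ordered positions (an assumption based on their numeric prefixes), take the probability-weighted expected position of each prediction, and average the absolute distance to the true position. The class biases are not reproduced here, and `CLASS_POSITION` and `ordinal_mae` are illustrative names.

```python
# Hypothetical ordinal positions for the four precision classes.
CLASS_POSITION = {
    "01-disease-subtype": 0,
    "02-disease-root": 1,
    "03-disease-area": 2,
    "04-non-disease": 3,
}


def ordinal_mae(y_true, proba, classes):
    """Mean absolute distance between each true class position and the
    probability-weighted (expected) predicted position.

    y_true:  list of class labels
    proba:   list of per-class probability rows, columns ordered as `classes`
    classes: list of class labels matching the probability columns
    """
    positions = [CLASS_POSITION[c] for c in classes]
    total = 0.0
    for label, row in zip(y_true, proba):
        expected = sum(p * pos for p, pos in zip(row, positions))
        total += abs(CLASS_POSITION[label] - expected)
    return total / len(y_true)
```

Unlike accuracy, this penalizes predicting 04-non-disease for a subtype more than predicting a neighbouring class.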

Some results:

For comparison:

More details about the best-performing model (food for thought)

The model seems to max out on BiasedMAE on the training data (not on the eval/CV one):

[Figure: BiasedMAE on the training data vs. the eval/CV data]

The model seems to slightly overfit the objective function:

[Figure: objective function, showing slight overfitting]

More details

For more details about the experiments & findings, take a look at the notebook. All the (non-production ready) code is in my branch.

yonromai commented 1 year ago

@eric-czech I think I now have enough understanding of the performance of the pre-GPT model that it'd be worth running the training data through the GPT4 prompt you provided and seeing if that does better than the model out of the box!

I can estimate the cost of the procedure if that's useful.

dhimmel commented 1 year ago

Nice finding that PCA works better on the node text embeddings than KNN and that 64 dimensions capture much of the performance benefit.

it'd be worth running the training data through the GPT4 prompt

I'm excited to see how the GPT4 features perform!

eric-czech commented 1 year ago

here is the gist of what I found

Very nice @yonromai! Great experimental setup and it's excellent to see some clear separation between those models.

Some results:

For posterity, I think it would be helpful to say more about what the lda_d7 and knn_d7 configurations were at this level. Presumably knn_d7 was the method in https://github.com/related-sciences/nxontology-ml/issues/8#issuecomment-1679531083. What was lda_d7 though?

Noting the current details in the notebook:

[Screenshot of the notebook, taken 2023-08-22]

running the training data through the GPT4 prompt you provided and seeing if that does better than the model out of the box!

Awesome -- I'd love to see how it performs on its own and when included as a feature with the other baseline features in a catboost gbm.

More details about the best performing model

OOC what is that UI you're looking at there? I don't see any obvious hints in https://github.com/related-sciences/nxontology-ml/tree/romain/embeddings/experimentation.

yonromai commented 1 year ago

For posterity, I think it would be helpful to say more about what the lda_d7 and knn_d7 configurations were at this level. Presumably knn_d7 was the method in https://github.com/related-sciences/nxontology-ml/issues/8#issuecomment-1679531083. What was lda_d7 though?

@eric-czech Noted, I'll add more details in the notebook. (The LDA code directly applies Sklearn's LDA, similar to the PCA - see this code)
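
A sketch of an LDA reduction with scikit-learn, assuming "LDA" here means `LinearDiscriminantAnalysis` (the thread doesn't say; scikit-learn's `LatentDirichletAllocation` is also abbreviated LDA, but it expects non-negative count features rather than embeddings). Synthetic data stands in for the text embeddings.

```python
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Synthetic stand-in for (n_terms, embedding_dim) embeddings with 4 labels.
X, y = make_classification(n_samples=400, n_features=50, n_informative=10,
                           n_classes=4, random_state=0)
# LDA is supervised and yields at most n_classes - 1 = 3 components.
lda = LinearDiscriminantAnalysis(n_components=3)
X_reduced = lda.fit_transform(X, y)
print(X_reduced.shape)  # (400, 3)
```

Unlike PCA, LDA uses the labels, so it should be fit on the training folds only (for example inside a scikit-learn `Pipeline`); fitting it on all data would leak label information into the evaluation.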

I'd like to clean up the code I have in my branch and merge it into the main branch. I'm probably going to end up deleting a lot of the code (e.g. the KNN part) in the near future, but that way it'll be saved in git history (along with the experimental setup).

@dhimmel: Would that be fine with you? (It's gonna be quite a big PR :( )

OOC what is that UI you're looking at there? I don't see any obvious hints in https://github.com/related-sciences/nxontology-ml/tree/romain/embeddings/experimentation.

The code which displays the model metrics is in the "CatBoost's MetricVisualizer" section of the notebook but it's JavaScript so it doesn't render in GH.

dhimmel commented 1 year ago

@dhimmel: Would that be fine with you? (It's gonna be quite a big PR :( )

Yes sounds good.

dhimmel commented 1 year ago

The code which displays the model metrics is in the "CatBoost's MetricVisualizer" section of the notebook but it's JavaScript so it doesn't render in GH.

Sometimes this will render in nbviewer, but not in this case.

yonromai commented 11 months ago

Are there short-term plans to work on this, or is it appropriate to close this issue?

dhimmel commented 11 months ago

Given that GPT assignments were inferior to text embedding features and didn't add much when combined, I don't think we need to use GPT features at all. Saves on cost and complexity.

eric-czech commented 11 months ago

I don't think we need to use GPT features at all

I definitely agree. Noting https://github.com/related-sciences/nxontology-ml/pull/34#issue-1912661241 as the most recent experiment at TOW that still had these features.