related-sciences / nxontology-ml

Machine learning to classify ontology nodes
Apache License 2.0

Add GPT-4 experiment on 500 samples #21

Closed: yonromai closed this PR 11 months ago

yonromai commented 12 months ago

This PR contains a notebook experiment where the GPT-4 tagger is used on a random sample of 500 nodes. The updated notebook is browsable here.

Performance

⇨ The performance of the GPT-4 tagger on the BiasedMAE metric is ~0.267

For reference (lower is better):

The proportions of GPT-4's output classes seem well aligned with the true labels:

GPT-4 labels proportions:
    high=0.47, medium=0.36, low=0.17
Samples (true) labels proportions:
    01-disease-subtype=0.51, 02-disease-root=0.38, 03-disease-area=0.11
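A minimal sketch of this proportion check. It assumes the correspondence high → 01-disease-subtype, medium → 02-disease-root, low → 03-disease-area, which is inferred from the ordering above and not confirmed elsewhere in the PR:

```python
from collections import Counter

# Assumed correspondence between GPT-4 precision labels and training classes,
# inferred from their ordering; not confirmed by the notebook.
LABEL_TO_CLASS = {
    "high": "01-disease-subtype",
    "medium": "02-disease-root",
    "low": "03-disease-area",
}


def proportions(labels: list[str]) -> dict[str, float]:
    """Return the share of each label, rounded to two decimals."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: round(n / total, 2) for label, n in sorted(counts.items())}


def compare_proportions(
    gpt_labels: list[str], true_labels: list[str]
) -> dict[str, tuple[float, float]]:
    """Pair each class's predicted share (via the assumed mapping) with its true share."""
    predicted = proportions([LABEL_TO_CLASS[lbl] for lbl in gpt_labels])
    actual = proportions(true_labels)
    classes = sorted(set(predicted) | set(actual))
    return {c: (predicted.get(c, 0.0), actual.get(c, 0.0)) for c in classes}
```

Matching marginals is a necessary but weak signal: two classifiers with identical proportions can disagree on every individual node.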

Comments/Remarks

Possible future steps?


⇨ @eric-czech @dhimmel @ravwojdyla Shall we brainstorm and figure out the next steps?

ravwojdyla commented 12 months ago

The experiment hasn't been run on all the nodes (I wanted to check in with you before tagging the entire dataset)

Great idea to validate it first on a smaller sample! ❤️

The proportions of GPT-4's output classes seem well aligned with the true labels:

👍

The completion results have been high variance (for example "NA", "high+xml" labels returned by the model)

That's interesting. How do you interpret it? Is it a sign that the problem, at least as currently stated, is difficult for the model? Also, I like the idea of mapping the completion labels to class probabilities!
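The mapping idea isn't spelled out in this thread; one possible reading is to sample several completions per node (e.g. at non-zero temperature), normalize each one, and use the label frequencies as class probabilities. The normalization rule below (keep a completion only if it starts with a valid label, so "high+xml" becomes "high" and "NA" is dropped) and the uniform fallback are assumptions for illustration:

```python
import re
from collections import Counter
from typing import Optional

VALID_LABELS = ("high", "medium", "low")

# Assumed normalization rule: keep a completion only if it starts with a valid
# label, so "high+xml" -> "high"; anything else (e.g. "NA") is discarded.
_LABEL_RE = re.compile(r"^\s*(high|medium|low)\b", re.IGNORECASE)


def normalize(completion: str) -> Optional[str]:
    """Map a raw completion to a valid label, or None if it cannot be salvaged."""
    match = _LABEL_RE.match(completion)
    return match.group(1).lower() if match else None


def label_probabilities(completions: list[str]) -> dict[str, float]:
    """Empirical class distribution over several sampled completions for one node."""
    labels = [lbl for lbl in map(normalize, completions) if lbl is not None]
    if not labels:
        # No usable completion: fall back to a uniform (uninformative) distribution.
        return {lbl: 1 / len(VALID_LABELS) for lbl in VALID_LABELS}
    counts = Counter(labels)
    return {lbl: counts[lbl] / len(labels) for lbl in VALID_LABELS}
```

Dropping unparseable completions keeps the distribution clean, but the drop rate itself may be worth logging as a per-node difficulty signal.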

Performance/BiasedMAE metric

Is there a performance breakdown by class somewhere in here? I'm essentially curious to check for Simpson's paradox.
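A per-class breakdown could look roughly like this. The exact weighting of BiasedMAE isn't shown in this thread, so the sketch uses plain MAE on an assumed ordinal scale as a stand-in, grouped by the true class:

```python
from collections import defaultdict

# Illustrative ordinal encoding (assumed): GPT-4 labels and true classes share
# one scale, ordered from most to least specific.
ORDINAL = {
    "01-disease-subtype": 0, "02-disease-root": 1, "03-disease-area": 2,
    "high": 0, "medium": 1, "low": 2,
}


def mae_by_class(true_labels: list[str], predicted_labels: list[str]) -> dict[str, float]:
    """Plain MAE on the ordinal scale, grouped by the *true* class.

    A stand-in for a per-class BiasedMAE, whose exact definition is not shown here.
    """
    if len(true_labels) != len(predicted_labels):
        raise ValueError("label lists must have the same length")
    errors: dict[str, list[int]] = defaultdict(list)
    for true_cls, pred in zip(true_labels, predicted_labels):
        errors[true_cls].append(abs(ORDINAL[true_cls] - ORDINAL[pred]))
    return {cls: sum(errs) / len(errs) for cls, errs in sorted(errors.items())}
```

Comparing these per-class numbers (and the class counts behind them) against the aggregate is what would expose a Simpson's-paradox-style reversal between two taggers.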

Iterate on the prompt (explain & slow down reasoning / chain of thought?)

Apart from maybe debugging some specific cases to make sure there are no big issues, if we want the GPT tagger to perform better, prompt engineering would certainly be an interesting improvement, specifically CoT^1. Re https://github.com/related-sciences/nxontology-ml/pull/20 and other potential issues with GPT skipping or duplicating records: maybe we could include a CSV in the prompt with the specific record IDs and an empty precision column (there are only 2 columns), and ask GPT to fill in the precision. I suspect this would make GPT much less likely to produce "corrupted" output.

eric-czech commented 12 months ago

This PR contains a notebook experiment where the GPT-4 tagger is used on a random sample of 500 nodes

🎉

The experiment hasn't been run on all the nodes (I wanted to check in with you before tagging the entire dataset)

Do you have some examples you can share before scaling up? Specifically, I mean logs of prompts being sent and responses received for them. I think this is all within the realm of expectation, but a second look could help.

Iterate on the prompt

I'm not opposed to adding more few-shot examples. It would help to have some examples like I mentioned above ... we might be able to add in "difficult" cases as a part of the prompt to some net positive effect.

Try adding the GPT-4 output as a feature in the CatBoost model

Highest-priority next step IMO, after first analyzing some of the misclassifications to check that nothing is obviously amiss (i.e., in your list: "Leverage difficult training records to iterate on the few shot examples in GPT-4 prompt"). I'm not at all surprised this performs worse than the CatBoost model with only topological features, and I think the most interesting comparison will be a CatBoost model with topological features + pca64 text embedding vs. a CatBoost model with topological features + GPT-4 classifications.
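A sketch of how the GPT-4 output might be appended to a node's feature row for the "topological + GPT-4" arm. The feature names (`gpt4_precision_ordinal`, `gpt4_precision_label`, and `depth` in the usage) are hypothetical, and the dict-per-row layout is an assumption rather than the repo's actual feature pipeline:

```python
import math
from typing import Optional

# Hypothetical numeric encoding that mirrors the label order.
PRECISION_ORDINAL = {"high": 0.0, "medium": 1.0, "low": 2.0}


def with_gpt4_features(
    topological: dict[str, float], gpt4_label: Optional[str]
) -> dict[str, object]:
    """Return a copy of a node's topological features plus GPT-4-derived columns."""
    row: dict[str, object] = dict(topological)
    valid = gpt4_label in PRECISION_ORDINAL
    # Numeric view: NaN marks nodes without a usable GPT-4 label.
    row["gpt4_precision_ordinal"] = PRECISION_ORDINAL[gpt4_label] if valid else math.nan
    # Categorical view: could be declared as categorical (e.g. CatBoost's cat_features).
    row["gpt4_precision_label"] = gpt4_label if valid else "missing"
    return row
```

Keeping both views lets the model exploit the ordering while still treating "missing" as its own category; in practice one would likely pick one after a quick ablation.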

dhimmel commented 12 months ago

Nice. I'm curious to see some prompts where GPT-4 provides the wrong answer (and perhaps seeing if CoT helps or makes the cause of the failure more transparent).

Also excited for:

Try adding the GPT-4 output as a feature in the CatBoost model

yonromai commented 12 months ago

Thanks a lot for the helpful comments!


@ravwojdyla

That's interesting - how do you interpret it? Is it a sign that the problem at least as stated right now is difficult for the model?

I wonder. I guess these kinds of hallucinations are to be expected from these LLMs? With enough data points, sampling randomness will eventually surface some weird results.

Also I like the idea of mapping the completion labels to class probability!

Right. On a similar note, I wonder if treating the problem as a regression/ranking problem could yield better results (for both the GPT and the CatBoost approaches), although coming up with supervision would likely be trickier and more arbitrary than with the current labels.
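One lightweight way to get a ranking out of the existing classifiers, without new supervision, is to collapse class probabilities into an expected position on an ordinal scale. The 0/1/2 scale below is an assumption; this only produces a continuous score from classifier output and does not address how a regression target would be defined:

```python
# Assumed ordinal scale: 0 = high precision (most specific) ... 2 = low.
PRECISION_ORDINAL = {"high": 0.0, "medium": 1.0, "low": 2.0}


def expected_ordinal(probs: dict[str, float]) -> float:
    """Probability-weighted position on the ordinal scale (a continuous score)."""
    total = sum(probs.values())
    return sum(PRECISION_ORDINAL[lbl] * p for lbl, p in probs.items()) / total


def rank_nodes(node_probs: dict[str, dict[str, float]]) -> list[str]:
    """Order nodes from most to least specific by their expected ordinal score."""
    return sorted(node_probs, key=lambda node: expected_ordinal(node_probs[node]))
```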

Is there a performance breakdown by class somewhere in here? I'm essentially curious to check for Simpson's paradox.

I haven't looked at the results by class (beyond class frequency). It's not much work; I can certainly add that to the notebook.

Apart from maybe debugging some specific cases to make sure there are no big issues, if we want the GPT tagger to perform better, prompt engineering would certainly be an interesting improvement, specifically CoT^1. Re https://github.com/related-sciences/nxontology-ml/pull/20 and other potential issues with GPT skipping or duplicating records

I agree. I guess CoT would bring some significant structural changes to the prompt (i.e., moving away from the current batched id,lbl structure). I wonder if adding a "reason" column would help the model spread its reasoning across more tokens (and make Karpathy proud of us).
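If a "reason" column were added, the free-text reason could itself contain the delimiter. A minimal parsing sketch, assuming a hypothetical `id|reason|precision` row layout: split the id off the left and the precision off the right, so whatever remains in the middle is the reason:

```python
def parse_reason_row(line: str, delimiter: str = "|") -> tuple[str, str, str]:
    """Parse a hypothetical `id|reason|precision` row.

    The id is split from the left and the precision from the right, so a
    delimiter inside the free-text reason does not break the row.
    """
    record_id, rest = line.split(delimiter, 1)
    reason, precision = rest.rsplit(delimiter, 1)
    return record_id.strip(), reason.strip(), precision.strip()
```

This relies on ids and precision values never containing the delimiter, which holds for CURIEs like `EFO:0006792` and the three allowed labels.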

maybe we could include a CSV in the prompt with the specific record IDs and an empty precision column (there are only 2 columns), and ask GPT to fill in the precision. I suspect this would make GPT much less likely to produce "corrupted" output.

Interesting, would you mind giving an example?

Something I'm not following here: for each input node, we provide text fields that the model uses to classify the node, and we probably want the model to output only the labels for the nodes (not the text fields). FWIW, here is an example of a filled prompt.
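One way to keep the two concerns apart: the text fields go only into the input records block, while the expected-output CSV lists only ids with an empty precision column. The sketch below assumes the YAML-like record list and `|`-delimited template that ravwojdyla uses in the next comment; `render_prompt_tail` is a hypothetical helper, not the repo's prompt builder:

```python
def render_prompt_tail(records: list[dict[str, str]]) -> str:
    """Render input records (with text fields) plus an ids-only output template."""
    lines = ["--- BEGIN RECORDS ---"]
    for rec in records:
        lines.append(f"- id: {rec['id']}")
        for field in ("label", "definition"):
            value = rec.get(field)
            if value:
                # Collapse internal whitespace/newlines so each field stays on one line.
                lines.append(f"  {field}: {' '.join(value.split())}")
    lines.append("--- END RECORDS ---")
    lines.append("")
    # The model only has to echo ids and fill in precision; no text fields here.
    lines.append("id|precision")
    lines.extend(f"{rec['id']}|" for rec in records)
    return "\n".join(lines)
```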


@eric-czech

Do you have some examples you can share before scaling up? Specifically, I mean logs of prompts being sent and responses received for them. I think this is all within the realm of expectation, but a second look could help.

For sure, will fetch a few more samples and provide that.

I'm not opposed to adding more few-shot examples. It would help to have some examples like I mentioned above ... we might be able to add in "difficult" cases as a part of the prompt to some net positive effect.

Perhaps I can look into the data points where the model is far off (e.g., high confidence in the wrong class)?
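A sketch of how those "far off" cases might be surfaced, assuming per-node class probabilities are available (e.g. from repeated sampling) and the true labels have already been mapped onto high/medium/low. Errors are ranked by ordinal distance first, then by confidence in the wrong class:

```python
from typing import Iterable

# Ordinal scale assumed for illustration (0 = most specific).
PRECISION_ORDINAL = {"high": 0, "medium": 1, "low": 2}


def most_confident_errors(
    rows: Iterable[tuple[str, str, dict[str, float]]], top_k: int = 5
) -> list[tuple[int, float, str, str, str]]:
    """Return (distance, confidence, node_id, predicted, true) for the worst misses.

    Each input row is (node_id, true_label, probs), with probs over high/medium/low.
    """
    errors = []
    for node_id, true_label, probs in rows:
        predicted = max(probs, key=probs.get)
        if predicted == true_label:
            continue
        distance = abs(PRECISION_ORDINAL[predicted] - PRECISION_ORDINAL[true_label])
        errors.append((distance, probs[predicted], node_id, predicted, true_label))
    # Largest ordinal distance first, then highest confidence in the wrong class.
    errors.sort(key=lambda e: (e[0], e[1]), reverse=True)
    return errors[:top_k]
```

The top of this list would be natural candidates for the "difficult" few-shot examples mentioned above.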

I think the most interesting comparison will be a CatBoost model with topological features + pca64 text embedding vs. a CatBoost model with topological features + GPT-4 classifications.

I agree. This will require labelling all the class 1-3 nodes, so perhaps we should see if we want to iterate on the prompt before proceeding with the extensive labelling task?


@dhimmel

Nice. I'm curious to see some prompts where GPT-4 provides the wrong answer (and perhaps seeing if CoT helps or makes the cause of the failure more transparent).

Cool! Will work on it

ravwojdyla commented 12 months ago

maybe we could include a CSV in the prompt ...

Interesting, would you mind giving an example?

@yonromai sure, here's an example for DOID:0050890 and EFO:0006792:

<EXISTING FEW-SHOT CONTEXT>

--- BEGIN RECORDS ---
- id: DOID:0050890
  label: synucleinopathy
  definition: A neurodegenerative disease that is characterized by the abnormal accumulation of aggregates of alpha-synuclein protein in neurons, nerve fibres or glial cells. [url:http://en.wikipedia.org/wiki/Synucleinopathies ]
- id: EFO:0006792
  label: Lewy body dementia
  definition: A progressive form of dementia characterized by the presence of protein deposits called Lewy bodies in the midbrain and cerebral cortex, and loss of cholinergic and dopaminergic neurons. The signs and symptoms overlap with Alzheimer and Parkinson disease.
--- END RECORDS ---

And below is the expected output CSV data with the `precision` column missing (`|` is the delimiter). Respond with EXACTLY the same CSV but with the `precision` column filled-in. You MUST assign precision to all records. The only valid values in the `precision` column are: `high`, `medium`, `low`. The order of the records MUST be the same as the expected output below. Respond with ONLY the CSV content, do not include explanation of any kind. Expected output:

id|precision
DOID:0050890|
EFO:0006792|

Example output from ChatGPT (GPT-4):

id|precision
DOID:0050890|medium
EFO:0006792|medium
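With the ids pinned in the template, checking for the skipped or duplicated records from PR #20 becomes a direct comparison. A minimal validation sketch, assuming the `|`-delimited `id|precision` format above; `parse_filled_csv` and `CompletionError` are hypothetical names:

```python
VALID_PRECISIONS = {"high", "medium", "low"}


class CompletionError(ValueError):
    """Raised when a filled-in CSV does not match the requested records."""


def parse_filled_csv(text: str, expected_ids: list[str], delimiter: str = "|") -> dict[str, str]:
    """Parse the model's filled-in CSV and reject skipped, duplicated or reordered ids."""
    rows = [line.strip() for line in text.strip().splitlines() if line.strip()]
    if not rows or rows[0] != f"id{delimiter}precision":
        raise CompletionError("missing or unexpected header")
    ids, result = [], {}
    for row in rows[1:]:
        record_id, sep, precision = row.partition(delimiter)
        if not sep:
            raise CompletionError(f"malformed row: {row!r}")
        record_id, precision = record_id.strip(), precision.strip().lower()
        if precision not in VALID_PRECISIONS:
            raise CompletionError(f"invalid precision {precision!r} for {record_id}")
        ids.append(record_id)
        result[record_id] = precision
    # One list comparison catches skipped, duplicated and reordered records.
    if ids != expected_ids:
        raise CompletionError(f"ids differ from the request: {ids}")
    return result
```

A failed batch could then simply be retried rather than partially merged.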


Here's an idea to allow CoT and batching:

<EXISTING FEW-SHOT CONTEXT>

--- BEGIN RECORDS ---
- id: DOID:0050890
  label: synucleinopathy
  definition: A neurodegenerative disease that is characterized by the abnormal accumulation of aggregates of alpha-synuclein protein in neurons, nerve fibres or glial cells. [url:http://en.wikipedia.org/wiki/Synucleinopathies ]
- id: EFO:0006792
  label: Lewy body dementia
  definition: A progressive form of dementia characterized by the presence of protein deposits called Lewy bodies in the midbrain and cerebral cortex, and loss of cholinergic and dopaminergic neurons. The signs and symptoms overlap with Alzheimer and Parkinson disease.
--- END RECORDS ---

As a detail-oriented, and smart disease ontology expert, for each record, you must provide a paragraph where you assign the precision and provide your reasoning, think step by step, be methodological. The order of the paragraphs must be the same as the order in the list above.

Finally when you are done with ALL the explanations for ALL records, write `<END OF COT>`, followed by a CSV  (`|` is the delimiter) below but with the `precision` column filled-in. You MUST assign precision to all records, use your reasoning above. The only valid values in the `precision` column are: `high`, `medium`, `low`. The order of the records MUST be the same as the expected output below. Respond with ONLY the CSV content, do not include explanation of any kind. Expected output:

id|precision
DOID:0050890|
EFO:0006792|

Example output from ChatGPT (GPT-4):

Record 1: synucleinopathy Definition: A neurodegenerative disease that is characterized by the abnormal accumulation of aggregates of alpha-synuclein protein in neurons, nerve fibres or glial cells.

Reasoning: The description of synucleinopathy refers to a broad category of neurodegenerative diseases based on a shared pathological characteristic, i.e., the accumulation of alpha-synuclein protein aggregates. This description does not go into the specifics of any one particular disease caused by this phenomenon. This general approach seems in line with descriptions that belong to the low precision category, such as "acquired metabolic disease" which is defined by a range of metabolic issues.

Precision: low

Record 2: Lewy body dementia Definition: A progressive form of dementia characterized by the presence of protein deposits called Lewy bodies in the midbrain and cerebral cortex, and loss of cholinergic and dopaminergic neurons. The signs and symptoms overlap with Alzheimer and Parkinson disease.

Reasoning: Lewy body dementia is described as a progressive form of dementia with specific characteristics, namely, the presence of Lewy bodies. The description provides more specificity than the broader categories like synucleinopathy but doesn't narrow down to a very specific and homogeneous patient population like the examples given for the high precision category. Its definition overlaps with other disease manifestations such as Alzheimer's and Parkinson's disease. This places it at a mid-point between general and highly specific definitions. This seems most in line with descriptions in the medium precision category, like "inflammatory bowel disease" which includes a range of related conditions.

Precision: medium

<END OF COT>

id|precision
DOID:0050890|low
EFO:0006792|medium

You can then take the output and split by <END OF COT>, and:
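A minimal sketch of that split, keeping the reasoning for inspection and parsing the CSV part. It splits on the *last* `<END OF COT>` in case the model quotes the marker inside its reasoning; the function names are hypothetical:

```python
END_OF_COT = "<END OF COT>"


def split_cot_output(text: str) -> tuple[str, str]:
    """Split a completion into (reasoning, csv) at the last <END OF COT> marker."""
    reasoning, marker, csv_part = text.rpartition(END_OF_COT)
    if not marker:
        raise ValueError("completion does not contain the <END OF COT> marker")
    return reasoning.strip(), csv_part.strip()


def parse_precisions(csv_part: str, delimiter: str = "|") -> dict[str, str]:
    """Parse the `id|precision` CSV that follows the marker."""
    lines = [line.strip() for line in csv_part.splitlines() if line.strip()]
    if not lines or lines[0] != f"id{delimiter}precision":
        raise ValueError("missing `id|precision` header after the marker")
    result = {}
    for row in lines[1:]:
        record_id, _, precision = row.partition(delimiter)
        result[record_id.strip()] = precision.strip()
    return result
```

The reasoning half is what would make misclassifications easier to diagnose; the CSV half would still go through the usual id/value validation.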

eric-czech commented 12 months ago

Perhaps I can look into the data points where the model is far off (e.g., high confidence in the wrong class)? ... perhaps we should see if we want to iterate on the prompt before proceeding with the extensive labelling task?

That ordering definitely makes sense to me; tune the prompt a bit if possible, then scale it up more.

yonromai commented 11 months ago

(Merging this PR, will iterate on a new branch / PR)