related-sciences / nxontology-ml

Machine learning to classify ontology nodes
Apache License 2.0

Add multi-completion support to GPT Tagger #15

Closed by yonromai 1 year ago

yonromai commented 1 year ago

In a recent discussion, @ravwojdyla suggested checking how much GPT-4's outputs vary when several completions are generated for a single prompt. This PR adds support for the n parameter of OpenAI's API, exposed as model_n in TaskConfig.

Here is a usage example, followed by its output (from the updated README):


from pprint import pprint

from nxontology_ml.data import get_efo_otar_slim
from nxontology_ml.gpt_tagger import TaskConfig, GptTagger
from nxontology_ml.utils import ROOT_DIR

# Create a config for EFO nodes labelling
config = TaskConfig(
    name="precision",
    prompt_path=ROOT_DIR / "prompts/precision_v1.txt",
    openai_model_name="gpt-4",
    node_attributes=["efo_id", "efo_label", "efo_definition"],
    model_n=3,  # Number of completions requested per prompt (OpenAI's n parameter)
)

# Get a few EFO nodes
nxo = get_efo_otar_slim()
nodes = (nxo.node_info(node) for node in sorted(nxo.graph)[:20])

# Get their labels
tagger = GptTagger.from_config(config)
for ln in tagger.fetch_labels(nodes):
    print(f"{ln.node_efo_id}: {ln.labels}")

# Inspect metrics
print("\nTagger metrics:")
pprint(tagger.get_metrics())

You should get an output similar to:

DOID:0050890: ['medium', 'medium', 'medium']
DOID:10113: ['low', 'low', 'low']
DOID:10718: ['low', 'low', 'low']
DOID:13406: ['medium', 'medium', 'medium']
DOID:1947: ['low', 'low', 'low']
DOID:7551: ['low', 'low', 'low']
EFO:0000094: ['high', 'high', 'high']
EFO:0000095: ['high', 'high', 'high']
EFO:0000096: ['medium', 'medium', 'medium']
EFO:0000174: ['high', 'high', 'high']
EFO:0000178: ['high', 'medium', 'medium']
EFO:0000180: ['low', 'low', 'medium']
EFO:0000181: ['high', 'medium', 'high']
EFO:0000182: ['high', 'medium', 'high']
EFO:0000183: ['high', 'medium', 'medium']
EFO:0000186: ['high', 'high', 'high']
EFO:0000191: ['high', 'high', 'high']
EFO:0000195: ['low', 'medium', 'medium']
EFO:0000196: ['high', 'high', 'high']
EFO:0000197: ['high', 'medium', 'medium']

Tagger metrics:
Counter({'ChatCompletion/total_tokens': 3543,
         'ChatCompletion/prompt_tokens': 3009,
         'ChatCompletion/completion_tokens': 534,
         'Cache/get': 20,
         'Cache/misses': 20,
         'ChatCompletion/records_processed': 20,
         'Cache/set': 20,
         'ChatCompletion/create_requests': 1})
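
To quantify how varied the completions are, one option is to count the nodes whose labels agree across all completions and take a majority vote for the rest. The following is a minimal sketch that works on the sample output above, copied into a plain dict; it does not call the tagger or the OpenAI API, and the names labels_by_node and majority_label are illustrative and not part of nxontology-ml.

```python
from collections import Counter

# Labels copied from the sample output above (node ID -> one label per completion).
labels_by_node = {
    "DOID:0050890": ["medium", "medium", "medium"],
    "DOID:10113": ["low", "low", "low"],
    "DOID:10718": ["low", "low", "low"],
    "DOID:13406": ["medium", "medium", "medium"],
    "DOID:1947": ["low", "low", "low"],
    "DOID:7551": ["low", "low", "low"],
    "EFO:0000094": ["high", "high", "high"],
    "EFO:0000095": ["high", "high", "high"],
    "EFO:0000096": ["medium", "medium", "medium"],
    "EFO:0000174": ["high", "high", "high"],
    "EFO:0000178": ["high", "medium", "medium"],
    "EFO:0000180": ["low", "low", "medium"],
    "EFO:0000181": ["high", "medium", "high"],
    "EFO:0000182": ["high", "medium", "high"],
    "EFO:0000183": ["high", "medium", "medium"],
    "EFO:0000186": ["high", "high", "high"],
    "EFO:0000191": ["high", "high", "high"],
    "EFO:0000195": ["low", "medium", "medium"],
    "EFO:0000196": ["high", "high", "high"],
    "EFO:0000197": ["high", "medium", "medium"],
}


def majority_label(labels):
    """Return (most common label, share of completions that agree with it).

    On a tie, Counter.most_common keeps the label seen first.
    """
    label, count = Counter(labels).most_common(1)[0]
    return label, count / len(labels)


# Nodes where every completion produced the same label.
unanimous = [node for node, labels in labels_by_node.items() if len(set(labels)) == 1]
print(f"{len(unanimous)}/{len(labels_by_node)} nodes have unanimous labels")

# For the remaining nodes, report the majority label and its share.
for node, labels in labels_by_node.items():
    if len(set(labels)) > 1:
        label, share = majority_label(labels)
        print(f"{node}: majority={label} ({share:.0%})")
```

On this sample, 13 of the 20 nodes are unanimous and the other 7 are 2-to-1 splits between adjacent labels (high/medium or low/medium).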