argilla-io / argilla

Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets
https://docs.argilla.io
Apache License 2.0

Logging a HF dataset with explanations fails due to conversion issues #1725

Open dvsrepo opened 2 years ago

dvsrepo commented 2 years ago

Describe the bug If you log a Text Classification dataset with explanations that comes from a HF dataset (after a round trip through to_datasets()), it might contain explanation values like:

'explanation': {'text': [{'attributions': {'NEGATIVE': 0.4228191208527419, 'POSITIVE': None}, ...

In the original Rubrix dataset, the attributions were:

'explanation': {'text': [{'attributions': {'NEGATIVE': 0.4228191208527419}, ...

So it seems that converting these dicts to Arrow automatically fills in the missing keys with None, because Arrow unifies them into a single struct type.
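A minimal sketch of the suspected behavior, using only the datasets library (the attribution values here are made up for illustration): Arrow stores the column as a struct with the union of all keys, so keys missing from a row come back as None.

from datasets import Dataset

# Two rows whose 'attributions' dicts have different keys
data = {"attributions": [{"NEGATIVE": 0.42}, {"POSITIVE": 0.91}]}
ds = Dataset.from_dict(data)

# The column is stored as a single struct type with both keys,
# so the key missing from a row is filled with None when read back
print(ds[0]["attributions"])
# {'NEGATIVE': 0.42, 'POSITIVE': None}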

The final error when logging is:

ValidationError: 1 validation error for TokenAttributions
attributions -> POSITIVE
  none is not an allowed value (type=type_error.none.not_allowed)

To Reproduce

Run this:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from datasets import load_dataset

import rubrix as rb
from rubrix import TokenAttributions

# Load Stanford sentiment treebank test set
dataset = load_dataset("sst", "default", split="test")

# Let's use a sentiment classifier fine-tuned on sst
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the explainer using transformers_interpret
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

records = []
for example in dataset.select(range(10)):

    # Build TokenAttributions objects,
    # ignoring the first (CLS) and last (SEP) tokens
    word_attributions = cls_explainer(example["sentence"])
    token_attributions = [
        TokenAttributions(
            token=token,
            attributions={cls_explainer.predicted_class_name: score}
        )
        for token, score in word_attributions[1:-1]
    ]
    # Build Text classification records
    record = rb.TextClassificationRecord(
        text=example["sentence"],
        prediction=[(cls_explainer.predicted_class_name, cls_explainer.pred_probs)],
        explanation={"text": token_attributions},
    )
    records.append(record)

rb_dataset = rb.DatasetForTextClassification(records)
hf_dataset = rb_dataset.to_datasets()
rb.log(rb.DatasetForTextClassification.from_datasets(hf_dataset), name="transformer_interpret_example_sst")

Or use this from the Hub:

import rubrix as rb

from datasets import load_dataset

ds = load_dataset("rubrix/transformer_interpret_example_sst", split="test")

rb.log(rb.DatasetForTextClassification.from_datasets(ds), name="transformer_interpret_example_sst")

Expected behavior Recovering a dataset with explanations from a HF dataset (for example, from the Hub) shouldn't fail.
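One possible direction, as a hedged sketch only (clean_explanation below is a hypothetical helper, not part of the Rubrix API): drop the None-valued attribution entries that the Arrow round trip introduced before the records are validated. The nested structure mirrors the 'explanation' values shown above.

def clean_explanation(explanation):
    """Remove None-valued attributions added by the Arrow round trip (hypothetical helper)."""
    if explanation is None:
        return None
    return {
        field: [
            {
                **token,
                "attributions": {
                    label: score
                    for label, score in token["attributions"].items()
                    if score is not None
                },
            }
            for token in tokens
        ]
        for field, tokens in explanation.items()
    }

# Example with the structure reported above
print(clean_explanation(
    {"text": [{"token": "great", "attributions": {"NEGATIVE": 0.4228191208527419, "POSITIVE": None}}]}
))
# {'text': [{'token': 'great', 'attributions': {'NEGATIVE': 0.4228191208527419}}]}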


github-actions[bot] commented 2 years ago

This issue is stale because it has been open for 30 days with no activity.

nataliaElv commented 5 months ago

@dvsrepo Is this still relevant for 2.0?