Describe the bug
If you log a Text Classification dataset with explanations that comes from a HF dataset (after a transformation with to_datasets()), it might contain explanation values in which labels missing from a record's attributions dict have been filled in with None. In the original Rubrix dataset each token's attributions only contained the predicted label; after the round trip the remaining labels show up with a None score. So it seems that transforming these dicts to Arrow automatically fills in the missing values.
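This looks like standard Arrow/datasets behaviour whenever a column of dicts has varying keys; a minimal sketch independent of Rubrix (the labels here are just illustrative):

from datasets import Dataset

# Arrow infers a struct with the union of all keys, so a row missing a key
# gets it back filled with None when the dataset is read out again.
ds = Dataset.from_dict({"attributions": [{"POSITIVE": 0.8}, {"NEGATIVE": -0.3}]})
print(ds[0]["attributions"])  # {'POSITIVE': 0.8, 'NEGATIVE': None}
print(ds[1]["attributions"])  # {'POSITIVE': None, 'NEGATIVE': -0.3}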
The final error when logging is:
ValidationError: 1 validation error for TokenAttributions
attributions -> POSITIVE
none is not an allowed value (type=type_error.none.not_allowed)
To Reproduce
Run this:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import SequenceClassificationExplainer
from datasets import load_dataset

import rubrix as rb
from rubrix import TokenAttributions

# Load Stanford sentiment treebank test set
dataset = load_dataset("sst", "default", split="test")

# Let's use a sentiment classifier fine-tuned on sst
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the explainer using transformers_interpret
cls_explainer = SequenceClassificationExplainer(model, tokenizer)

records = []
for example in dataset.select(range(10)):
    # Build Token attributions objects
    word_attributions = cls_explainer(example["sentence"])
    token_attributions = [
        TokenAttributions(
            token=token,
            attributions={cls_explainer.predicted_class_name: score},
        )  # ignore first (CLS) and last (SEP) tokens
        for token, score in word_attributions[1:-1]
    ]

    # Build Text classification records
    record = rb.TextClassificationRecord(
        text=example["sentence"],
        prediction=[(cls_explainer.predicted_class_name, cls_explainer.pred_probs)],
        explanation={"text": token_attributions},
    )
    records.append(record)

# Round-trip through a HF dataset and log it back; this is what triggers the error
rb_dataset = rb.DatasetForTextClassification(records)
hf_dataset = rb_dataset.to_datasets()

# rb.log needs a dataset name; reusing the one from the Hub example below
rb.log(
    rb.DatasetForTextClassification.from_datasets(hf_dataset),
    name="transformer_interpret_example_sst",
)
Or use this from the Hub:
import rubrix as rb
from datasets import load_dataset
ds = load_dataset("rubrix/transformer_interpret_example_sst", split="test")
rb.log(rb.DatasetForTextClassification.from_datasets(ds), name="transformer_interpret_example_sst")
Expected behavior
Recovering a dataset with explanations from a HF dataset (e.g. one pushed to the Hub) shouldn't fail.
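A possible direction for a fix (just a sketch, not how from_datasets() is currently implemented) would be to drop the None scores Arrow adds before the TokenAttributions objects are built:

def drop_none_scores(attributions: dict) -> dict:
    """Remove the None scores Arrow fills in for labels missing from a row."""
    return {label: score for label, score in attributions.items() if score is not None}

# what a round-tripped attributions dict can look like after to_datasets()
assert drop_none_scores({"POSITIVE": 0.8, "NEGATIVE": None}) == {"POSITIVE": 0.8}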
Screenshots
If applicable, add screenshots to help explain your problem.
Environment (please complete the following information):
OS [e.g. iOS]:
Browser [e.g. chrome, safari]:
Rubrix Version [e.g. 0.10.0]:
ElasticSearch Version [e.g. 7.10.2]:
Docker Image (optional) [e.g. rubrix:v0.10.0]:
Additional context
Add any other context about the problem here.