huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer: TypeError: an integer is required (got type NoneType) #16975

Closed: loretoparisi closed this issue 2 years ago

loretoparisi commented 2 years ago

System Info

- `transformers` version: 4.18.0
- Platform: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.13
- Huggingface_hub version: 0.5.1
- PyTorch version (GPU?): 1.10.0+cu111 (True)
- Tensorflow version (GPU?): 2.8.0 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Who can help?

@lys

Information

Tasks

Reproduction

from datasets import load_dataset,Features,Value,ClassLabel
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import TrainingArguments
import numpy as np
from datasets import load_metric
import torch

class_names = ["cmn","deu","rus","fra","eng","jpn","spa","ita","kor","vie","nld","epo","por","tur","heb","hun","ell","ind","ara","arz","fin","bul","yue","swe","ukr","bel","que","ces","swh","nno","wuu","nob","zsm","est","kat","pol","lat","urd","sqi","isl","fry","afr","ron","fao","san","bre","tat","yid","uig","uzb","srp","qya","dan","pes","slk","eus","cycl","acm","tgl","lvs","kaz","hye","hin","lit","ben","cat","bos","hrv","tha","orv","cha","mon","lzh","scn","gle","mkd","slv","frm","glg","vol","ain","jbo","tok","ina","nds","mal","tlh","roh","ltz","oss","ido","gla","mlt","sco","ast","jav","oci","ile","ota","xal","tel","sjn","nov","khm","tpi","ang","aze","tgk","tuk","chv","hsb","dsb","bod","sme","cym","mri","ksh","kmr","ewe","kab","ber","tpw","udm","lld","pms","lad","grn","mlg","xho","pnb","grc","hat","lao","npi","cor","nah","avk","mar","guj","pan","kir","myv","prg","sux","crs","ckt","bak","zlm","hil","cbk","chr","nav","lkt","enm","arq","lin","abk","pcd","rom","gsw","tam","zul","awa","wln","amh","bar","hbo","mhr","bho","mrj","ckb","osx","pfl","mgm","sna","mah","hau","kan","nog","sin","glv","dng","kal","liv","vro","apc","jdt","fur","che","haw","yor","crh","pdc","ppl","kin","shs","mnw","tet","sah","kum","ngt","nya","pus","hif","mya","moh","wol","tir","ton","lzz","oar","lug","brx","non","mww","hak","nlv","ngu","bua","aym","vec","ibo","tkl","bam","kha","ceb","lou","fuc","smo","gag","lfn","arg","umb","tyv","kjh","oji","cyo","urh","kzj","pam","srd","lmo","swg","mdf","gil","snd","tso","sot","zza","tsn","pau","som","egl","ady","asm","ori","dtp","cho","max","kam","niu","sag","ilo","kaa","fuv","nch","hoc","iba","gbm","sun","war","mvv","pap","ary","kxi","csb","pag","cos","rif","kek","krc","aii","ban","ssw","tvl","mfe","tah","bvy","bcl","hnj","nau","nst","afb","quc","min","tmw","mad","bjn","mai","cjy","got","hsn","gan","tzl","dws","ldn","afh","sgs","krl","vep","rue","tly","mic","ext","izh","sma","jam","cmo","mwl","kpv","koi","bis","ike","run","evn","ryu","mnc","aoz","otk","kas","aln","akl","yua","shy","fkv","gos","fij","thv","zgh","gcf","cay","xmf","tig","div","lij","rap","hrx","cpi","tts","gaa","tmr","iii","ltg","bzt","syc","emx","gom","chg","osp","stq","frr","fro","nys","toi","new","phn","jpa","rel","drt","chn","pli","laa","bal","hdn","hax","mik","ajp","xqa","pal","crk","mni","lut","ayl","ood","sdh","ofs","nus","kiu","diq","qxq","alt","bfz","klj","mus","srn","guc","lim","zea","shi","mnr","bom","sat","szl"]
features = Features({ 'label': ClassLabel(names=class_names), 'text': Value('string')})
num_labels = features['label'].num_classes
data_files = { "train": "train.csv", "test": "test.csv" }
sentences = load_dataset(
     "loretoparisi/tatoeba-sentences",
     data_files=data_files,
     delimiter='\t', 
     column_names=['label', 'text'],
     download_mode="force_redownload"
)
print(sentences)
# You can make this part faster with num_proc=<some int>
sentences = sentences.map(lambda ex: {"label" : features["label"].str2int(ex["label"]) if ex["label"] is not None else None}, features=features)
sentences = sentences.shuffle()

model_name = 'microsoft/xtremedistil-l6-h256-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
tokenized_datasets = sentences.map(tokenize_function, batched=True)

full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model = model.to(device)
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    print(eval_pred)
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
training_args = TrainingArguments("test_trainer",
                                  per_device_train_batch_size=128, 
                                  num_train_epochs=24,learning_rate=3e-05,
                                  evaluation_strategy="epoch")
from transformers import Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=full_train_dataset,
    eval_dataset=full_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

Stack trace:

***** Running training *****
  Num examples = 8256315
  Num Epochs = 24
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 1548072
[ 4942/1548072 50:35 < 263:23:43, 1.63 it/s, Epoch 0.08/24]

Epoch | Training Loss | Validation Loss
-- | -- | --

Saving model checkpoint to test_trainer/checkpoint-500
Configuration saved in test_trainer/checkpoint-500/config.json
Model weights saved in test_trainer/checkpoint-500/pytorch_model.bin
Saving model checkpoint to test_trainer/checkpoint-1000
Configuration saved in test_trainer/checkpoint-1000/config.json
Model weights saved in test_trainer/checkpoint-1000/pytorch_model.bin
Saving model checkpoint to test_trainer/checkpoint-1500
...
Saving model checkpoint to test_trainer/checkpoint-4500
Configuration saved in test_trainer/checkpoint-4500/config.json
Model weights saved in test_trainer/checkpoint-4500/pytorch_model.bin
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-3435b262f1ae> in <module>()
----> 1 trainer.train()

5 frames
/usr/local/lib/python3.7/dist-packages/transformers/data/data_collator.py in torch_default_data_collator(features)
    113         label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
    114         dtype = torch.long if isinstance(label, int) else torch.float
--> 115         batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    116     elif "label_ids" in first and first["label_ids"] is not None:
    117         if isinstance(first["label_ids"], torch.Tensor):

TypeError: an integer is required (got type NoneType)
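
For context, the default data collator builds the labels tensor for the whole batch in a single torch.tensor call, so one row whose label is None is enough to trigger this error. A minimal sketch of the failure (not the actual Trainer code path; the exact message may vary by PyTorch version):

import torch

# one row with a missing (None) label, as produced by the map above
batch_labels = [3, None, 7]

# the default collator does essentially this for the "label" column;
# None cannot be converted to an integer, so it raises
# "TypeError: an integer is required (got type NoneType)"
labels = torch.tensor(batch_labels, dtype=torch.long)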

Expected behavior

Training completes successfully.
loretoparisi commented 2 years ago

Could this issue be related to this SF answer?

The dataset looks like

Dataset({
    features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 8256315
})

and features

{'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': ClassLabel(num_classes=403, names=['cmn', 'deu', 'rus', 'fra', 'eng', 'jpn', 'spa', 'ita', 'kor', 'vie', 'nld', 'epo', 'por', 'tur', 'heb', 'hun', 'ell', 'ind', 'ara', 'arz', 'fin', 'bul', 'yue', 'swe', 'ukr', 'bel', 'que', 'ces', 'swh', 'nno', 'wuu', 'nob', 'zsm', 'est', 'kat', 'pol', 'lat', 'urd', 'sqi', 'isl', 'fry', 'afr', 'ron', 'fao', 'san', 'bre', 'tat', 'yid', 'uig', 'uzb', 'srp', 'qya', 'dan', 'pes', 'slk', 'eus', 'cycl', 'acm', 'tgl', 'lvs', 'kaz', 'hye', 'hin', 'lit', 'ben', 'cat', 'bos', 'hrv', 'tha', 'orv', 'cha', 'mon', 'lzh', 'scn', 'gle', 'mkd', 'slv', 'frm', 'glg', 'vol', 'ain', 'jbo', 'tok', 'ina', 'nds', 'mal', 'tlh', 'roh', 'ltz', 'oss', 'ido', 'gla', 'mlt', 'sco', 'ast', 'jav', 'oci', 'ile', 'ota', 'xal', 'tel', 'sjn', 'nov', 'khm', 'tpi', 'ang', 'aze', 'tgk', 'tuk', 'chv', 'hsb', 'dsb', 'bod', 'sme', 'cym', 'mri', 'ksh', 'kmr', 'ewe', 'kab', 'ber', 'tpw', 'udm', 'lld', 'pms', 'lad', 'grn', 'mlg', 'xho', 'pnb', 'grc', 'hat', 'lao', 'npi', 'cor', 'nah', 'avk', 'mar', 'guj', 'pan', 'kir', 'myv', 'prg', 'sux', 'crs', 'ckt', 'bak', 'zlm', 'hil', 'cbk', 'chr', 'nav', 'lkt', 'enm', 'arq', 'lin', 'abk', 'pcd', 'rom', 'gsw', 'tam', 'zul', 'awa', 'wln', 'amh', 'bar', 'hbo', 'mhr', 'bho', 'mrj', 'ckb', 'osx', 'pfl', 'mgm', 'sna', 'mah', 'hau', 'kan', 'nog', 'sin', 'glv', 'dng', 'kal', 'liv', 'vro', 'apc', 'jdt', 'fur', 'che', 'haw', 'yor', 'crh', 'pdc', 'ppl', 'kin', 'shs', 'mnw', 'tet', 'sah', 'kum', 'ngt', 'nya', 'pus', 'hif', 'mya', 'moh', 'wol', 'tir', 'ton', 'lzz', 'oar', 'lug', 'brx', 'non', 'mww', 'hak', 'nlv', 'ngu', 'bua', 'aym', 'vec', 'ibo', 'tkl', 'bam', 'kha', 'ceb', 'lou', 'fuc', 'smo', 'gag', 'lfn', 'arg', 'umb', 'tyv', 'kjh', 'oji', 'cyo', 'urh', 'kzj', 'pam', 'srd', 'lmo', 'swg', 'mdf', 'gil', 'snd', 'tso', 'sot', 'zza', 'tsn', 'pau', 'som', 'egl', 'ady', 'asm', 'ori', 'dtp', 'cho', 'max', 'kam', 'niu', 'sag', 'ilo', 'kaa', 'fuv', 'nch', 'hoc', 'iba', 'gbm', 'sun', 'war', 'mvv', 'pap', 'ary', 'kxi', 'csb', 'pag', 'cos', 'rif', 'kek', 'krc', 'aii', 'ban', 'ssw', 'tvl', 'mfe', 'tah', 'bvy', 'bcl', 'hnj', 'nau', 'nst', 'afb', 'quc', 'min', 'tmw', 'mad', 'bjn', 'mai', 'cjy', 'got', 'hsn', 'gan', 'tzl', 'dws', 'ldn', 'afh', 'sgs', 'krl', 'vep', 'rue', 'tly', 'mic', 'ext', 'izh', 'sma', 'jam', 'cmo', 'mwl', 'kpv', 'koi', 'bis', 'ike', 'run', 'evn', 'ryu', 'mnc', 'aoz', 'otk', 'kas', 'aln', 'akl', 'yua', 'shy', 'fkv', 'gos', 'fij', 'thv', 'zgh', 'gcf', 'cay', 'xmf', 'tig', 'div', 'lij', 'rap', 'hrx', 'cpi', 'tts', 'gaa', 'tmr', 'iii', 'ltg', 'bzt', 'syc', 'emx', 'gom', 'chg', 'osp', 'stq', 'frr', 'fro', 'nys', 'toi', 'new', 'phn', 'jpa', 'rel', 'drt', 'chn', 'pli', 'laa', 'bal', 'hdn', 'hax', 'mik', 'ajp', 'xqa', 'pal', 'crk', 'mni', 'lut', 'ayl', 'ood', 'sdh', 'ofs', 'nus', 'kiu', 'diq', 'qxq', 'alt', 'bfz', 'klj', 'mus', 'srn', 'guc', 'lim', 'zea', 'shi', 'mnr', 'bom', 'sat', 'szl'], id=None),
 'text': Value(dtype='string', id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}
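
Since the earlier map keeps None for rows whose label is missing, a quick sanity check (hypothetical, not part of the original post) is to count how many such rows survive in the train split:

# hypothetical check: count training rows whose label is still None
none_rows = tokenized_datasets["train"].filter(lambda ex: ex["label"] is None)
print(none_rows.num_rows)
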
loretoparisi commented 2 years ago

[UPDATE] I have changed the tokenize function like this:

def tokenize_function(batch):
    tokens = tokenizer(batch['text'], padding="max_length", truncation=True, max_length=128)
    tokens['label'] = features["label"].str2int(batch['label'])
    return tokens
tokenized_datasets = sentences.map(tokenize_function, batched=True)

and removed the mapping as defined above, but now I'm facing a None label issue:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-39-3f04e6ec6f6e> in <module>()
     14     tokens['label'] = features["label"].str2int(batch['label']) if batch["label"] is not None else None
     15     return tokens
---> 16 tokenized_datasets = sentences.map(tokenize, batched=True)

10 frames
/usr/local/lib/python3.7/dist-packages/datasets/features/features.py in str2int(self, values)
    852                 if value not in self._str2int:
    853                     value = str(value).strip()
--> 854                 output.append(self._str2int[str(value)])
    855             else:
    856                 # No names provided, try to integerize

KeyError: 'None'
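
For reference, ClassLabel.str2int stringifies values it does not recognize before the lookup, so a None entry becomes the string "None" and fails exactly like this. A minimal sketch with a shortened label set:

from datasets import ClassLabel

label_feature = ClassLabel(names=["eng", "fra", "deu"])
print(label_feature.str2int(["eng", "fra"]))  # [0, 1]

# a None entry is stringified to "None", which is not a known class name,
# so the lookup raises KeyError: 'None'
label_feature.str2int(["eng", None])
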
loretoparisi commented 2 years ago

Solved by filtering out the None rows

sentences = sentences.filter(lambda example: example['label'] is not None and example['text'] is not None)
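
On a dataset of this size (~8M rows) the same filter can also be run batched and in parallel; a possible variant (the num_proc value is an assumption to tune to the machine):

# equivalent batched/parallel filter; the function returns one boolean per row
sentences = sentences.filter(
    lambda batch: [lbl is not None and txt is not None
                   for lbl, txt in zip(batch["label"], batch["text"])],
    batched=True,
    num_proc=4,  # assumption: adjust to the available CPU cores
)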

and slightly changing the tokenize function:

from datasets import load_dataset,Features,Value,ClassLabel

class_names = ["cmn","deu","rus","fra","eng","jpn","spa","ita","kor","vie","nld","epo","por","tur","heb","hun","ell","ind","ara","arz","fin","bul","yue","swe","ukr","bel","que","ces","swh","nno","wuu","nob","zsm","est","kat","pol","lat","urd","sqi","isl","fry","afr","ron","fao","san","bre","tat","yid","uig","uzb","srp","qya","dan","pes","slk","eus","cycl","acm","tgl","lvs","kaz","hye","hin","lit","ben","cat","bos","hrv","tha","orv","cha","mon","lzh","scn","gle","mkd","slv","frm","glg","vol","ain","jbo","tok","ina","nds","mal","tlh","roh","ltz","oss","ido","gla","mlt","sco","ast","jav","oci","ile","ota","xal","tel","sjn","nov","khm","tpi","ang","aze","tgk","tuk","chv","hsb","dsb","bod","sme","cym","mri","ksh","kmr","ewe","kab","ber","tpw","udm","lld","pms","lad","grn","mlg","xho","pnb","grc","hat","lao","npi","cor","nah","avk","mar","guj","pan","kir","myv","prg","sux","crs","ckt","bak","zlm","hil","cbk","chr","nav","lkt","enm","arq","lin","abk","pcd","rom","gsw","tam","zul","awa","wln","amh","bar","hbo","mhr","bho","mrj","ckb","osx","pfl","mgm","sna","mah","hau","kan","nog","sin","glv","dng","kal","liv","vro","apc","jdt","fur","che","haw","yor","crh","pdc","ppl","kin","shs","mnw","tet","sah","kum","ngt","nya","pus","hif","mya","moh","wol","tir","ton","lzz","oar","lug","brx","non","mww","hak","nlv","ngu","bua","aym","vec","ibo","tkl","bam","kha","ceb","lou","fuc","smo","gag","lfn","arg","umb","tyv","kjh","oji","cyo","urh","kzj","pam","srd","lmo","swg","mdf","gil","snd","tso","sot","zza","tsn","pau","som","egl","ady","asm","ori","dtp","cho","max","kam","niu","sag","ilo","kaa","fuv","nch","hoc","iba","gbm","sun","war","mvv","pap","ary","kxi","csb","pag","cos","rif","kek","krc","aii","ban","ssw","tvl","mfe","tah","bvy","bcl","hnj","nau","nst","afb","quc","min","tmw","mad","bjn","mai","cjy","got","hsn","gan","tzl","dws","ldn","afh","sgs","krl","vep","rue","tly","mic","ext","izh","sma","jam","cmo","mwl","kpv","koi","bis","ike","run","evn","ryu","mnc","aoz","otk","kas","aln","akl","yua","shy","fkv","gos","fij","thv","zgh","gcf","cay","xmf","tig","div","lij","rap","hrx","cpi","tts","gaa","tmr","iii","ltg","bzt","syc","emx","gom","chg","osp","stq","frr","fro","nys","toi","new","phn","jpa","rel","drt","chn","pli","laa","bal","hdn","hax","mik","ajp","xqa","pal","crk","mni","lut","ayl","ood","sdh","ofs","nus","kiu","diq","qxq","alt","bfz","klj","mus","srn","guc","lim","zea","shi","mnr","bom","sat","szl"]
features = Features({ 'label': ClassLabel(names=class_names), 'text': Value('string')})
num_labels = features['label'].num_classes
data_files = { "train": "train.csv", "test": "test.csv" }
sentences = load_dataset(
     "loretoparisi/tatoeba-sentences",
     data_files=data_files,
     delimiter='\t', 
     column_names=['label', 'text'],
     download_mode="force_redownload")
sentences = sentences.filter(lambda example: example['label'] is not None and example['text'] is not None)
sentences = sentences.shuffle()

from transformers import AutoTokenizer

model_name = 'microsoft/xtremedistil-l6-h256-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize(batch):
    tokens = tokenizer(batch['text'], padding="max_length", truncation=True, max_length=128)
    tokens['label'] = features["label"].str2int(batch['label'])
    return tokens
tokenized_datasets = sentences.map(tokenize, batched=True)

full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]

import torch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
model = model.to(device)

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    print(eval_pred)
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

from transformers import TrainingArguments
training_args = TrainingArguments("test_trainer",
                                  per_device_train_batch_size=128, 
                                  num_train_epochs=24,learning_rate=3e-05,
                                  evaluation_strategy="epoch")
from transformers import Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=full_train_dataset,
    eval_dataset=full_eval_dataset,
    compute_metrics=compute_metrics,
)
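
As in the original reproduction, the run is presumably then started with:

trainer.train()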