ThilinaRajapakse / simpletransformers

Transformers for Information Retrieval, Text Classification, NER, QA, Language Modelling, Language Generation, T5, Multi-Modal, and Conversational AI
https://simpletransformers.ai/
Apache License 2.0
4.08k stars 727 forks source link

Added compatibility for T5TokenizerFast #1547

Closed robinvandernoord closed 11 months ago

robinvandernoord commented 1 year ago

I wanted to use a T5 model that was missing a spiece.model file.

import torch
from transformers.models.t5 import T5TokenizerFast, T5Tokenizer
from simpletransformers.t5 import (
    T5Model,
    T5Args,
)

MODEL_NAME = "yhavinga/t5-v1.1-base-dutch-cased"
MODEL_TYPE = "t5"

tokenizer = T5TokenizerFast.from_pretrained(MODEL_NAME, truncate=True)

model_args = T5Args()

model = T5Model(
    MODEL_TYPE,
    MODEL_NAME,
    args=model_args,
    use_cuda=torch.cuda.is_available(),
    tokenizer=tokenizer
)

This led to the following exception:

File /opt/conda/lib/python3.10/site-packages/simpletransformers/t5/t5_model.py:139, in T5Model.__init__(self, model_type, model_name, args, tokenizer, use_cuda, cuda_device, **kwargs)
    137     self.tokenizer = ByT5Tokenizer.from_pretrained(model_name, truncate=True)
    138 else:
--> 139     self.tokenizer = T5Tokenizer.from_pretrained(model_name, truncate=True)
    141 if self.args.dynamic_quantize:
    142     self.model = torch.quantization.quantize_dynamic(
    143         self.model, {torch.nn.Linear}, dtype=torch.qint8
    144     )

File /opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1854, in PreTrainedTokenizerBase.from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, *init_inputs, **kwargs)
   1851     else:
   1852         logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
-> 1854 return cls._from_pretrained(
   1855     resolved_vocab_files,
   1856     pretrained_model_name_or_path,
   1857     init_configuration,
   1858     *init_inputs,
   1859     token=token,
   1860     cache_dir=cache_dir,
   1861     local_files_only=local_files_only,
   1862     _commit_hash=commit_hash,
   1863     _is_local=is_local,
   1864     **kwargs,
   1865 )

File /opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:2017, in PreTrainedTokenizerBase._from_pretrained(cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, token, cache_dir, local_files_only, _commit_hash, _is_local, *init_inputs, **kwargs)
   2015 # Instantiate tokenizer.
   2016 try:
-> 2017     tokenizer = cls(*init_inputs, **init_kwargs)
   2018 except OSError:
   2019     raise OSError(
   2020         "Unable to load vocabulary from file. "
   2021         "Please check that the provided vocabulary is accessible and not corrupted."
   2022     )

File /opt/conda/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5.py:194, in T5Tokenizer.__init__(self, vocab_file, eos_token, unk_token, pad_token, extra_ids, additional_special_tokens, sp_model_kwargs, legacy, **kwargs)
    191 self.vocab_file = vocab_file
    192 self._extra_ids = extra_ids
--> 194 self.sp_model = self.get_spm_processor()

File /opt/conda/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5.py:199, in T5Tokenizer.get_spm_processor(self)
    197 tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    198 if self.legacy:  # no dependency on protobuf
--> 199     tokenizer.Load(self.vocab_file)
    200     return tokenizer
    202 with open(self.vocab_file, "rb") as f:

File /opt/conda/lib/python3.10/site-packages/sentencepiece/__init__.py:905, in SentencePieceProcessor.Load(self, model_file, model_proto)
    903 if model_proto:
    904   return self.LoadFromSerializedProto(model_proto)
--> 905 return self.LoadFromFile(model_file)

File /opt/conda/lib/python3.10/site-packages/sentencepiece/__init__.py:310, in SentencePieceProcessor.LoadFromFile(self, arg)
    309 def LoadFromFile(self, arg):
--> 310     return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)

TypeError: not a string

The 'tokenizer' parameter is checked to be a T5Tokenizer in the __init__ of T5Model. However, T5TokenizerFast is not a subclass of that class, even though it is compatible. Checking for both classes should fix this issue.