urchade / GLiNER

Generalist and Lightweight Model for Named Entity Recognition (Extract any entity types from texts) @ NAACL 2024
https://arxiv.org/abs/2311.08526
Apache License 2.0

TypeError: DebertaV2Model.__init__() got an unexpected keyword argument 'subtoken_pooling' #26

Closed rxavier closed 7 months ago

rxavier commented 8 months ago

I'm trying to integrate this model into an existing app that hosts other zero-shot models (not for NER, but there are a lot of shared dependencies).

I went ahead and installed GLiNER. This is my installed package list after doing so (Python 3.10.4):

Package                               Version
------------------------------------- -----------
aniso8601                             9.0.1
annotated-types                       0.6.0
anyio                                 4.3.0
apolo                                 1.2.1
app                                   0.0.0
artifacts-toolkit                     1.9.3
atomicwrites                          1.4.1
attrs                                 23.2.0
beautifulsoup4                        4.12.3
blinker                               1.7.0
boto3                                 1.34.54
botocore                              1.34.54
bpemb                                 0.3.5
certifi                               2023.11.17
cffi                                  1.16.0
charset-normalizer                    3.3.2
click                                 8.1.7
cloudpickle                           3.0.0
cmake                                 3.28.3
contextlib2                           21.6.0
contourpy                             1.2.0
coverage                              7.4.0
cryptography                          42.0.5
cycler                                0.12.1
datadog                               0.40.1
decorator                             5.1.1
defusedxml                            0.7.1
Deprecated                            1.2.14
exceptiongroup                        1.2.0
filelock                              3.13.1
flair                                 0.6.1.post1
Flask                                 2.3.3
Flask-RESTful                         0.3.10
fonttools                             4.50.0
frozendict                            2.1.1
fsspec                                2024.2.0
ftfy                                  6.1.3
future                                1.0.0
gdown                                 5.1.0
gensim                                4.3.2
gevent                                22.10.1
gliner                                0.1.2
greenlet                              1.1.3.post0
gunicorn                              19.9.0
h11                                   0.14.0
httpcore                              0.17.3
httpx                                 0.24.1
huggingface-hub                       0.21.3
hyperopt                              0.2.7
idna                                  3.6
image-downloader                      1.3.4
importlib-metadata                    6.11.0
iniconfig                             2.0.0
itsdangerous                          2.1.2
Janome                                0.5.0
jeepney                               0.8.0
Jinja2                                3.1.3
jmespath                              1.0.1
joblib                                1.3.2
keyring                               21.8.0
kiwisolver                            1.4.5
konoha                                4.6.3
langdetect                            1.0.9
lit                                   17.0.6
lxml                                  5.1.0
MarkupSafe                            2.1.5
marshmallow                           3.21.0
matplotlib                            3.8.3
more-itertools                        10.2.0
mpld3                                 0.3
mpmath                                1.3.0
nest-asyncio                          1.6.0
networkx                              3.2.1
newrelic                              8.11.0
numpy                                 1.26.4
nvidia-cublas-cu11                    11.10.3.66
nvidia-cuda-cupti-cu11                11.7.101
nvidia-cuda-nvrtc-cu11                11.7.99
nvidia-cuda-runtime-cu11              11.7.99
nvidia-cudnn-cu11                     8.5.0.96
nvidia-cufft-cu11                     10.9.0.58
nvidia-curand-cu11                    10.2.10.91
nvidia-cusolver-cu11                  11.4.0.1
nvidia-cusparse-cu11                  11.7.4.91
nvidia-nccl-cu11                      2.14.3
nvidia-nvtx-cu11                      11.7.91
open-clip-torch                       2.24.0
opentelemetry-api                     1.23.0
opentelemetry-instrumentation         0.41b0
opentelemetry-instrumentation-httpx   0.41b0
opentelemetry-instrumentation-urllib3 0.41b0
opentelemetry-semantic-conventions    0.41b0
opentelemetry-util-http               0.41b0
overrides                             3.1.0
packaging                             22.0
pandas                                2.2.1
pillow                                10.2.0
pip                                   23.2.1
pluggy                                0.13.1
protobuf                              4.25.3
py                                    1.11.0
py4j                                  0.10.9.7
pyarrow                               15.0.0
pycparser                             2.21
pycryptodomex                         3.20.0
pycurl                                7.45.3
pydantic                              2.6.3
pydantic_core                         2.16.3
PyJWT                                 2.8.0
Pympler                               0.9
pyOpenSSL                             24.0.0
pyparsing                             3.1.2
PySocks                               1.7.1
pytest                                7.4.3
pytest-cov                            4.1.0
pytest-mock                           3.12.0
python-dateutil                       2.8.2
python-slugify                        7.0.0
pytz                                  2024.1
PyYAML                                6.0.1
regex                                 2023.12.25
requests                              2.31.0
s3transfer                            0.10.0
safetensors                           0.4.2
schema                                0.7.5
scikit-learn                          1.4.1.post1
scipy                                 1.12.0
SecretStorage                         3.3.3
segtok                                1.5.11
sentencepiece                         0.2.0
seqeval                               1.2.2
setuptools                            68.0.0
six                                   1.16.0
smart-open                            7.0.1
sniffio                               1.3.1
soupsieve                             2.5
sqlitedict                            2.1.0
sympy                                 1.12
tabulate                              0.9.0
text-unidecode                        1.3
threadpoolctl                         3.3.0
timm                                  0.9.16
tokenizers                            0.15.2
tomli                                 2.0.1
torch                                 2.0.1
torchvision                           0.15.2
tqdm                                  4.66.2
transformers                          4.38.2
triton                                2.0.0
typing_extensions                     4.10.0
tzdata                                2024.1
ubatch                                1.0.0
urllib3                               1.26.18
urllib3-secure-extra                  0.1.0
uWSGI                                 2.0.21
wcwidth                               0.2.13
Werkzeug                              3.0.1
wheel                                 0.41.1
wrapt                                 1.16.0
zipp                                  3.17.0
zope.event                            5.0
zope.interface                        6.0

And this is the error I get when running the example in this repo's README, with both the base and multi models:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 3
      1 from gliner import GLiNER
----> 3 model = GLiNER.from_pretrained("urchade/gliner_base")
      5 text = """
      6 Cristiano Ronaldo dos Santos Aveiro (Portuguese pronunciation: [kɾiʃˈtjɐnu ʁɔˈnaldu]; born 5 February 1985) is a Portuguese professional footballer who plays as a forward for and captains both Saudi Pro League club Al Nassr and the Portugal national team. Widely regarded as one of the greatest players of all time, Ronaldo has won five Ballon d'Or awards,[note 3] a record three UEFA Men's Player of the Year Awards, and four European Golden Shoes, the most by a European player. He has won 33 trophies in his career, including seven league titles, five UEFA Champions Leagues, the UEFA European Championship and the UEFA Nations League. Ronaldo holds the records for most appearances (183), goals (140) and assists (42) in the Champions League, goals in the European Championship (14), international goals (128) and international appearances (205). He is one of the few players to have made over 1,200 professional career appearances, the most by an outfield player, and has scored over 850 official senior career goals for club and country, making him the top goalscorer of all time.
      7 """
      9 labels = ["person", "award", "date", "competitions", "teams"]

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:118, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    115 if check_use_auth_token:
    116     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 118 return fn(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/huggingface_hub/hub_mixin.py:277, in ModelHubMixin.from_pretrained(cls, pretrained_model_name_or_path, force_download, resume_download, proxies, token, cache_dir, local_files_only, revision, **model_kwargs)
    273     elif any(param.kind == inspect.Parameter.VAR_KEYWORD for param in init_parameters.values()):
    274         # If __init__ accepts **kwargs, let's forward the config as well (as a dict)
    275         model_kwargs["config"] = config
--> 277 instance = cls._from_pretrained(
    278     model_id=str(model_id),
    279     revision=revision,
    280     cache_dir=cache_dir,
    281     force_download=force_download,
    282     proxies=proxies,
    283     resume_download=resume_download,
    284     local_files_only=local_files_only,
    285     token=token,
    286     **model_kwargs,
    287 )
    289 # Implicitly set the config as instance attribute if not already set by the class
    290 # This way `config` will be available when calling `save_pretrained` or `push_to_hub`.
    291 if config is not None and instance.config is None:

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/gliner/model.py:405, in GLiNER._from_pretrained(cls, model_id, revision, cache_dir, force_download, proxies, resume_download, local_files_only, token, map_location, strict, **model_kwargs)
    393     config_file = hf_hub_download(
    394         repo_id=model_id,
    395         filename="gliner_config.json",
   (...)
    402         local_files_only=local_files_only,
    403     )
    404 config = load_config_as_namespace(config_file)
--> 405 model = cls(config)
    406 state_dict = torch.load(model_file, map_location=torch.device(map_location))
    407 model.load_state_dict(state_dict, strict=strict, assign=True)

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/gliner/model.py:32, in GLiNER.__init__(self, config)
     29 self.sep_token = "<<SEP>>"
     31 # usually a pretrained bidirectional transformer, returns first subtoken representation
---> 32 self.token_rep_layer = TokenRepLayer(model_name=config.model_name, fine_tune=config.fine_tune,
     33                                      subtoken_pooling=config.subtoken_pooling, hidden_size=config.hidden_size,
     34                                      add_tokens=[self.entity_token, self.sep_token])
     36 # hierarchical representation of tokens (zaratiana et al, 2022)
     37 # https://arxiv.org/pdf/2203.14710.pdf
     38 self.rnn = LstmSeq2SeqEncoder(
     39     input_size=config.hidden_size,
     40     hidden_size=config.hidden_size // 2,
     41     num_layers=1,
     42     bidirectional=True,
     43 )

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/gliner/modules/token_rep.py:20, in TokenRepLayer.__init__(self, model_name, fine_tune, subtoken_pooling, hidden_size, add_tokens)
     14 def __init__(self, model_name: str = "bert-base-cased", fine_tune: bool = True, subtoken_pooling: str = "first",
     15              hidden_size: int = 768,
     16              add_tokens=["[SEP]", "[ENT]"]
     17              ):
     18     super().__init__()
---> 20     self.bert_layer = TransformerWordEmbeddings(
     21         model_name,
     22         fine_tune=fine_tune,
     23         subtoken_pooling=subtoken_pooling,
     24         allow_long_sentences=True
     25     )
     27     # add tokens to vocabulary
     28     self.bert_layer.tokenizer.add_tokens(add_tokens)

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/flair/embeddings/token.py:802, in TransformerWordEmbeddings.__init__(self, model, layers, pooling_operation, batch_size, use_scalar_mix, fine_tune, allow_long_sentences, **kwargs)
    800 self.tokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
    801 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
--> 802 self.model = AutoModel.from_pretrained(model, config=config, **kwargs)
    804 self.allow_long_sentences = allow_long_sentences
    806 if allow_long_sentences:

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:561, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    559 elif type(config) in cls._model_mapping.keys():
    560     model_class = _get_model_class(config, cls._model_mapping)
--> 561     return model_class.from_pretrained(
    562         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    563     )
    564 raise ValueError(
    565     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    566     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    567 )

File ~/.cache/pypoetry/virtualenvs/app-j_VG4LQw-py3.10/lib/python3.10/site-packages/transformers/modeling_utils.py:3375, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3369 config = cls._autoset_attn_implementation(
   3370     config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
   3371 )
   3373 with ContextManagers(init_contexts):
   3374     # Let's make sure we don't run the init function of buffer modules
-> 3375     model = cls(config, *model_args, **model_kwargs)
   3377 # make sure we use the model's config since the __init__ call might have copied it
   3378 config = model.config

TypeError: DebertaV2Model.__init__() got an unexpected keyword argument 'subtoken_pooling'

urchade commented 8 months ago

I think you should upgrade flair. Try version 0.13.1.
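
For readers wondering why a flair upgrade would matter here: the traceback shows the installed TransformerWordEmbeddings.__init__ has no subtoken_pooling parameter but does accept **kwargs, and it forwards those kwargs to AutoModel.from_pretrained. The sketch below reproduces that pattern with hypothetical stand-in classes (Backbone and OldEmbeddings are illustrative names, not flair or transformers code). It is a simplified picture of the failure mode, not the libraries' actual implementation.

```python
class Backbone:
    """Hypothetical stand-in for the underlying model class."""

    def __init__(self, config):
        self.config = config


class OldEmbeddings:
    """Hypothetical stand-in for an older wrapper without `subtoken_pooling`."""

    def __init__(self, model, fine_tune=True, allow_long_sentences=False, **kwargs):
        # Any keyword not named in the signature lands in kwargs and is
        # forwarded unchanged to the backbone constructor.
        self.model = Backbone(config=model, **kwargs)


def try_build(**kwargs):
    """Return "ok" on success, or the TypeError message on failure."""
    try:
        OldEmbeddings("some-model", **kwargs)
        return "ok"
    except TypeError as exc:
        return str(exc)


# The unknown keyword is only rejected at the innermost constructor.
print(try_build(fine_tune=False))
print(try_build(subtoken_pooling="first"))
```

The error therefore surfaces in the innermost class even though the mismatch is between the caller and the wrapper, which matches the shape of the traceback above.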

rxavier commented 8 months ago

Thanks, that fixed that error, but now I get:

TypeError: Module.load_state_dict() got an unexpected keyword argument 'assign'

I think the assign kwarg was added in torch 2.1.0, and I'm using 2.0.1. Would it be worth either specifying a minimum torch version in GLiNER's requirements or setting assign=False, which is the default in torch?
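
A third option, sketched here only as an illustration and not as anything GLiNER does, is to pass assign only when the installed load_state_dict accepts it. The helper name load_state_dict_compat and the two stand-in classes are hypothetical; the stand-ins mimic the two signatures so the example runs without torch installed.

```python
import inspect


def load_state_dict_compat(model, state_dict, strict=True):
    """Call model.load_state_dict, passing assign=True only if it is supported."""
    params = inspect.signature(model.load_state_dict).parameters
    kwargs = {"strict": strict}
    if "assign" in params:
        # Newer signature: the keyword exists, so it is safe to pass.
        kwargs["assign"] = True
    return model.load_state_dict(state_dict, **kwargs)


class OldModule:
    """Hypothetical stand-in for a module whose load_state_dict lacks `assign`."""

    def load_state_dict(self, state_dict, strict=True):
        return ("old", strict)


class NewModule:
    """Hypothetical stand-in for a module whose load_state_dict has `assign`."""

    def load_state_dict(self, state_dict, strict=True, assign=False):
        return ("new", strict, assign)


print(load_state_dict_compat(OldModule(), {}))
print(load_state_dict_compat(NewModule(), {}))
```

Inspecting the signature avoids hard-coding a version number, at the cost of silently changing loading behavior between environments; declaring a minimum torch version is the more explicit alternative.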

urchade commented 8 months ago

Same here: you can remove "assign" in the code.

rxavier commented 8 months ago

I'm not sure whether you're suggesting I edit it for my use case or that I submit a PR with the change.

urchade commented 8 months ago

Edit it for your own use.

This is due to a version mismatch; some other people have encountered the same error.

rxavier commented 8 months ago

In that case, since GLiNER does not work with torch < 2.1.0, shouldn't this be specified in the requirements?

urchade commented 8 months ago

python = "^3.8.0"
click = "^8.0.1"
torch = ">=2.0.0"
transformers = "^4.38.2"
huggingface-hub = "^0.21.4"
flair = "^0.13.1"
seqeval = "^1.2.2"
tqdm = "^4.66.2"

rxavier commented 8 months ago

I don't think those requirements are anywhere in this repo, so they won't be checked when installing GLiNER. In fact, when I first installed it, it pulled in flair==0.6.1.post1.

In any case, by using assign within load_state_dict you're effectively requiring torch >= 2.1.0, which is even higher than the torch >=2.0.0 listed above.
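
To make the mismatch in this thread concrete, here is a minimal sketch that compares installed versions against minimums and reports the packages that fall short. The function names (release_tuple, meets_minimum, find_too_old) are hypothetical, and the comparison is deliberately simplified: it looks only at the leading numeric release components and ignores suffixes such as .post1, so it is not a full PEP 440 implementation. The version numbers are the ones quoted in this issue.

```python
import re


def release_tuple(version):
    """Extract the leading numeric components, e.g. "0.6.1.post1" -> (0, 6, 1)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component (e.g. "post1")
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum` (release numbers only)."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "2.1" and "2.1.0" compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def find_too_old(installed, minimums):
    """List packages whose installed version is below the stated minimum."""
    return [
        name
        for name, minimum in minimums.items()
        if name in installed and not meets_minimum(installed[name], minimum)
    ]


# Versions from the package list at the top of this issue.
installed = {"torch": "2.0.1", "flair": "0.6.1.post1", "transformers": "4.38.2"}
# Minimums discussed in the thread (torch 2.1.0 because of the assign kwarg).
minimums = {"torch": "2.1.0", "flair": "0.13.1", "transformers": "4.38.2"}

print(find_too_old(installed, minimums))
```

In a real environment the installed versions could come from importlib.metadata.version, and declaring the minimums in the package metadata would let the installer enforce them instead of failing at runtime.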

urchade commented 7 months ago

@rxavier do you still have the issue? Try installing a newer version of gliner.