ArneBinder / pytorch-ie-hydra-template-1

PyTorch-IE Hydra Template
8 stars, 1 fork

TransformerSpanClassificationTaskModule is not a registered name for TaskModule. #162

Closed: taghizad3h closed this issue 5 months ago

taghizad3h commented 5 months ago

When running the example from the README in Colab, I encountered the following error:

```python
from dataclasses import dataclass

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.auto import AutoPipeline
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextDocument

@dataclass
class ExampleDocument(TextDocument):
    entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")

document = ExampleDocument(
    "“Making a super tasty alt-chicken wing is only half of it,” said Po Bronson, general partner at SOSV and managing director of IndieBio."
)

# see below for the long version
ner_pipeline = AutoPipeline.from_pretrained("pie/example-ner-spanclf-conll03", device=-1, num_workers=0)

ner_pipeline(document)

for entity in document.entities.predictions:
    print(f"{entity} -> {entity.label}")
```
```
---------------------------------------------------------------------------

RegistrationError                         Traceback (most recent call last)

<ipython-input-6-266b4f95f95f> in <cell line: 19>()
     17 
     18 # see below for the long version
---> 19 ner_pipeline = AutoPipeline.from_pretrained("pie/example-ner-spanclf-conll03", device=-1, num_workers=0)
     20 
     21 ner_pipeline(document)

/usr/local/lib/python3.10/dist-packages/pytorch_ie/auto.py in from_pretrained(pretrained_model_name_or_path, force_download, resume_download, proxies, use_auth_token, cache_dir, local_files_only, taskmodule_kwargs, model_kwargs, device, binary_output, **kwargs)
    126         model_kwargs = model_kwargs or {}
    127 
--> 128         taskmodule = AutoTaskModule.from_pretrained(
    129             pretrained_model_name_or_path=pretrained_model_name_or_path,
    130             force_download=force_download,

/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py in _inner_fn(*args, **kwargs)
    116             kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
    117 
--> 118         return fn(*args, **kwargs)
    119 
    120     return _inner_fn  # type: ignore

/usr/local/lib/python3.10/dist-packages/pytorch_ie/core/hf_hub_mixin.py in from_pretrained(cls, pretrained_model_name_or_path, force_download, resume_download, proxies, token, cache_dir, local_files_only, revision, **model_kwargs)
    182         model_kwargs["is_from_pretrained"] = True
    183 
--> 184         return cls._from_pretrained(
    185             model_id=str(model_id),
    186             revision=revision,

/usr/local/lib/python3.10/dist-packages/pytorch_ie/auto.py in _from_pretrained(cls, model_id, revision, cache_dir, force_download, proxies, resume_download, local_files_only, token, map_location, strict, config, **taskmodule_kwargs)
     93         config.update(taskmodule_kwargs)
     94         class_name = config.pop(cls.config_type_key)
---> 95         clazz: Type[TaskModule] = TaskModule.by_name(class_name)
     96         taskmodule = clazz(**config)
     97         taskmodule.post_prepare()

/usr/local/lib/python3.10/dist-packages/pytorch_ie/core/registrable.py in by_name(cls, name)
     46             return Registrable._registry[cls][name]
     47 
---> 48         raise RegistrationError(f"{name} is not a registered name for {cls.__name__}.")
     49 
     50     @classmethod

RegistrationError: TransformerSpanClassificationTaskModule is not a registered name for TaskModule.
```
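The last two frames show the mechanism: `_from_pretrained` pops the class name out of the downloaded config and hands it to `TaskModule.by_name`, which can only return classes that are already in the registry. A minimal stdlib-only sketch of that config-to-class step, assuming a plain dict registry; `REGISTRY`, `from_config`, and the config key `"taskmodule_type"` are hypothetical illustrations, not pytorch-ie's actual code:

```python
from typing import Any, Dict, Type


class RegistrationError(Exception):
    """Raised when a name was never registered (mirrors the error above)."""


# Hypothetical registry: maps a base class to {registered name: subclass}.
REGISTRY: Dict[type, Dict[str, type]] = {}


class TaskModule:
    @classmethod
    def by_name(cls, name: str) -> Type["TaskModule"]:
        # Look the name up for this base class, as registrable.py does above.
        if name in REGISTRY.get(cls, {}):
            return REGISTRY[cls][name]
        raise RegistrationError(f"{name} is not a registered name for {cls.__name__}.")


def from_config(config: Dict[str, Any], type_key: str = "taskmodule_type") -> TaskModule:
    # Copy so the caller's dict is not mutated, then pop the class name like auto.py does.
    config = dict(config)
    class_name = config.pop(type_key)
    clazz = TaskModule.by_name(class_name)
    # Remaining config entries become constructor arguments.
    return clazz(**config)


# Nothing has been registered in this script, so loading the saved config fails.
saved_config = {"taskmodule_type": "TransformerSpanClassificationTaskModule"}
try:
    from_config(saved_config)
except RegistrationError as err:
    print(err)
```

Because the lookup is purely by name, the saved config alone cannot bring the class into existence; the code that defines and registers it must already have run in the current process.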
ArneBinder commented 5 months ago

Hi @taghizad3h,

Adding `from pytorch_ie.taskmodules import *` to the beginning of your code should fix it.

Thanks for reporting this! I will change the example accordingly.
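Why the import helps: in a registry of this kind, a class is typically added when its `class` statement executes (for example through a decorator), so importing the module that defines the task modules is what populates the registry before the name lookup happens. A minimal sketch of that ordering, assuming decorator-based registration; `Registrable`, `register`, and `by_name` here are simplified stand-ins modeled on the `registrable.py` frame in the traceback, not the library's code:

```python
from typing import Callable, Dict, Optional, TypeVar

T = TypeVar("T", bound=type)


class RegistrationError(Exception):
    pass


class Registrable:
    # One shared dict keyed by base class, like Registrable._registry in the traceback.
    _registry: Dict[type, Dict[str, type]] = {}

    @classmethod
    def register(cls, name: Optional[str] = None) -> Callable[[T], T]:
        def decorator(subclass: T) -> T:
            # Runs when the decorated class statement executes, i.e. at import time.
            Registrable._registry.setdefault(cls, {})[name or subclass.__name__] = subclass
            return subclass

        return decorator

    @classmethod
    def by_name(cls, name: str) -> type:
        if name in Registrable._registry.get(cls, {}):
            return Registrable._registry[cls][name]
        raise RegistrationError(f"{name} is not a registered name for {cls.__name__}.")


class TaskModule(Registrable):
    pass


# Before the defining code has run (the module was never imported): lookup fails.
try:
    TaskModule.by_name("TransformerSpanClassificationTaskModule")
except RegistrationError as err:
    print("before:", err)


# This class statement stands in for what importing the taskmodules package executes.
@TaskModule.register()
class TransformerSpanClassificationTaskModule(TaskModule):
    pass


print("after:", TaskModule.by_name("TransformerSpanClassificationTaskModule").__name__)
```

In this sketch, what matters is that the class statement runs before `by_name` is called; a wildcard import is simply one way to make sure every task module in the package has been defined.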

taghizad3h commented 5 months ago

OK, thank you. It seems that I reported the issue in the wrong repository; I was using the example from the main pytorch-ie repo. Sorry for the confusion 😅.

ArneBinder commented 5 months ago

Ah, OK, thanks for letting me know :) I created a new issue there.