JonasGeiping / cramming

Cramming the training of a (BERT-type) language model into limited compute.
MIT License

Finetuning for token classification #40

Closed druskacik closed 8 months ago

druskacik commented 8 months ago

I'd like to fine-tune this model for a token classification task. As suggested in #35, instantiating from AutoModelForTokenClassification should work. However, I get an error.

import cramming
from transformers import AutoTokenizer, AutoModelForTokenClassification

model  = AutoModelForTokenClassification.from_pretrained("JonasGeiping/crammed-bert", num_labels=3)

>>> ---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[46], line 1
----> 1 model  = AutoModelForTokenClassification.from_pretrained("JonasGeiping/crammed-bert", num_labels=3)

File ~\.conda\envs\product_scanner\lib\site-packages\transformers\models\auto\auto_factory.py:566, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    564 elif type(config) in cls._model_mapping.keys():
    565     model_class = _get_model_class(config, cls._model_mapping)
--> 566     return model_class.from_pretrained(
    567         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    568     )
    569 raise ValueError(
    570     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    571     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    572 )

File ~\.conda\envs\product_scanner\lib\site-packages\transformers\modeling_utils.py:3462, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3456 config = cls._autoset_attn_implementation(
   3457     config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
   3458 )
   3460 with ContextManagers(init_contexts):
   3461     # Let's make sure we don't run the init function of buffer modules
-> 3462     model = cls(config, *model_args, **model_kwargs)
   3464 # make sure we use the model's config since the __init__ call might have copied it
   3465 config = model.config

File ~\.conda\envs\product_scanner\lib\site-packages\cramming\architectures\crammed_bert.py:396, in ScriptableLMForTokenClassification.__init__(self, config)
    393 self.cfg = OmegaConf.create(config.arch)
    395 self.encoder = ScriptableLM(config)
--> 396 self.head = torch.nn.Linear(self.cfg.classification_head.head_dim, self.num_labels)
    398 self.problem_type = None
    399 self._init_weights()

File ~\.conda\envs\product_scanner\lib\site-packages\torch\nn\modules\module.py:1614, in Module.__getattr__(self, name)
   1612     if name in modules:
   1613         return modules[name]
-> 1614 raise AttributeError("'{}' object has no attribute '{}'".format(
   1615     type(self).__name__, name))

AttributeError: 'ScriptableLMForTokenClassification' object has no attribute 'num_labels'

Versions:

transformers==4.36.2
torch==2.0.1
JonasGeiping commented 8 months ago

Hi, did you see https://github.com/JonasGeiping/cramming/issues/34#issuecomment-1745372529? I think this might be a related problem.

druskacik commented 8 months ago

Okay, that solved the issue. I also had to change the head_dim value because of a matrix multiplication error. Thanks.
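
For later readers, a rough sketch of that kind of workaround, assuming the num_labels fix from #34 has been applied locally. The key names arch, classification_head, head_dim, and hidden_size are taken from the traceback and the released checkpoint's config and may differ for other checkpoints.

import cramming  # registers the crammed-bert architecture with the Auto classes
from transformers import AutoConfig, AutoModelForTokenClassification

# Sketch only: set the label count and make the classification head width match
# the encoder output so the Linear layer does not hit a matmul shape mismatch.
config = AutoConfig.from_pretrained("JonasGeiping/crammed-bert", num_labels=3)
config.arch["classification_head"]["head_dim"] = config.arch["hidden_size"]  # assumed key names
model = AutoModelForTokenClassification.from_pretrained("JonasGeiping/crammed-bert", config=config)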

JonasGeiping commented 8 months ago

Ok, I'm glad!

For anyone reading this in the future, I'm also definitely accepting PRs to fix this problem; I just haven't had time to do it myself.
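
In case it helps whoever takes that on, here is a minimal sketch of the kind of change the traceback above points at (reconstructed from the traceback, not a confirmed upstream fix): assign num_labels from the config in ScriptableLMForTokenClassification.__init__ before the classification head is built.

# Hypothetical patch sketch for ScriptableLMForTokenClassification.__init__
# in cramming/architectures/crammed_bert.py; reconstructed from the traceback
# above, not a confirmed upstream fix.
def __init__(self, config):
    super().__init__(config)
    self.cfg = OmegaConf.create(config.arch)
    self.num_labels = config.num_labels  # set this before the head below reads it

    self.encoder = ScriptableLM(config)
    self.head = torch.nn.Linear(self.cfg.classification_head.head_dim, self.num_labels)

    self.problem_type = None
    self._init_weights()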