sileod / tasknet

Easy multi-task learning with HuggingFace Datasets and Trainer

Unable to load saved model #7

Open deewhy26 opened 1 year ago

deewhy26 commented 1 year ago

Hello, sorry, I'm quite new to writing issues. I trained a joint token classification and sequence classification model and saved it with trainer.save_model("multi_task/"). However, when I try to load the same model with model_2 = tn.load_pipeline("/kaggle/working/multi_task", "intent_classification", adapt_task_embedding=True), I get the error below (a simplified sketch of my setup follows the traceback):

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:1                                                                                    │
│                                                                                                  │
│ ❱ 1 model_2 =  tn.load_pipeline("/kaggle/working/multi_task","intent_classification", adapt_     │
│   2                                                                                              │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/tasknet/utils.py:198 in load_pipeline                    │
│                                                                                                  │
│   195 │   │   import tasksource                                                                  │
│   196 │   except:                                                                                │
│   197 │   │   raise ImportError("Requires tasksource.\n pip install tasksource")                 │
│ ❱ 198 │   task = tasksource.load_task(task_name, multilingual=multilingual)                      │
│   199 │                                                                                          │
│   200 │   model = AutoModelForSequenceClassification.from_pretrained(                            │
│   201 │   │   model_name, ignore_mismatched_sizes=True                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/tasksource/access.py:102 in load_task                    │
│                                                                                                  │
│    99 │   query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)           │
│   100 │   query = {k:v for k,v in query.items() if v}                                            │
│   101 │   _tasks = (lmtasks if multilingual else tasks)                                          │
│ ❱ 102 │   preprocessing = load_preprocessing(_tasks, **query)                                    │
│   103 │   dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name, **load   │
│   104 │   dataset= preprocessing(dataset,max_rows, max_rows_eval)                                │
│   105 │   dataset.task_type = preprocessing.__class__.__name__                                   │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/tasksource/access.py:90 in load_preprocessing            │
│                                                                                                  │
│    87                                                                                            │
│    88 def load_preprocessing(tasks=tasks, **kwargs):                                             │
│    89 │   _tasks_df = list_tasks(multilingual=tasks==lmtasks)                                    │
│ ❱  90 │   y = _tasks_df.copy().query(dict_to_query(**kwargs)).iloc[0]                            │
│    91 │   preprocessing= copy.copy(getattr(tasks, y.preprocessing_name))                         │
│    92 │   for c in 'dataset_name','config_name':                                                 │
│    93 │   │   if not isinstance(getattr(preprocessing,c), str):                                  │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1073 in __getitem__              │
│                                                                                                  │
│   1070 │   │   │   axis = self.axis or 0                                                         │
│   1071 │   │   │                                                                                 │
│   1072 │   │   │   maybe_callable = com.apply_if_callable(key, self.obj)                         │
│ ❱ 1073 │   │   │   return self._getitem_axis(maybe_callable, axis=axis)                          │
│   1074 │                                                                                         │
│   1075 │   def _is_scalar_access(self, key: tuple):                                              │
│   1076 │   │   raise NotImplementedError()                                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1625 in _getitem_axis            │
│                                                                                                  │
│   1622 │   │   │   │   raise TypeError("Cannot index by location index with a non-integer key")  │
│   1623 │   │   │                                                                                 │
│   1624 │   │   │   # validate the location                                                       │
│ ❱ 1625 │   │   │   self._validate_integer(key, axis)                                             │
│   1626 │   │   │                                                                                 │
│   1627 │   │   │   return self.obj._ixs(key, axis=axis)                                          │
│   1628                                                                                           │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1557 in _validate_integer        │
│                                                                                                  │
│   1554 │   │   """                                                                               │
│   1555 │   │   len_axis = len(self.obj._get_axis(axis))                                          │
│   1556 │   │   if key >= len_axis or key < -len_axis:                                            │
│ ❱ 1557 │   │   │   raise IndexError("single positional indexer is out-of-bounds")                │
│   1558 │                                                                                         │
│   1559 │   # -------------------------------------------------------------------                 │
│   1560                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
IndexError: single positional indexer is out-of-bounds
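
For reference, here is roughly what my setup looked like. This is a simplified sketch written from memory based on the tasknet README; the actual datasets, column names, and task labels I used are different, so treat them as placeholders:

    import tasknet as tn
    from datasets import load_dataset

    class hparams:
        model_name = "microsoft/deberta-v3-base"  # shared encoder for both tasks
        learning_rate = 3e-5

    tasks = [
        # sequence classification task (intent classification in my real setup;
        # the dataset and column names here are placeholders)
        tn.Classification(dataset=load_dataset("glue", "rte"),
                          s1="sentence1", s2="sentence2", y="label"),
        # token classification task (again, a placeholder dataset)
        tn.TokenClassification(dataset=load_dataset("conll2003"),
                               tokens="tokens", y="ner_tags"),
    ]

    model = tn.Model(tasks, hparams)
    trainer = tn.Trainer(model, tasks, hparams)
    trainer.train()

    # save the jointly trained model, then try to reload it as a pipeline
    trainer.save_model("multi_task/")
    model_2 = tn.load_pipeline("/kaggle/working/multi_task",
                               "intent_classification",
                               adapt_task_embedding=True)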

I intend to load the model from this checkpoint for domain adaptation. How do I address this issue? Also, if you could point me to some resources for understanding adapters, that would be pretty helpful (I found out about them from the other issue posted here). PS: Thank you for this incredible library; it saved me a lot of time.

sileod commented 1 year ago

Hi, thank you for the issue. Would it be easy to reproduce the error in a Colab?

deewhy26 commented 1 year ago

Hi, thank you for the issue. Would it be easy to reproduce the error in a Colab?

Sorry, I don't understand. Do you want me to try this in a Colab? The error above was produced on Kaggle.