Open - hadariru opened this issue 1 month ago
cc @molbap
Thanks for the issue @hadariru - just one note: it looks like the fine-tuning itself is working (i.e. if you let the loss go down and don't add eval); it's the evaluation part in Trainer that has an issue. It seems the only way for losses to never be accessed would be prediction_step failing. cc @muellerzr in case you are familiar; I will take a look at this soon.
@molbap Yes, the evaluation part is giving me an error; training itself is working fine. I can see the fine-tune is working okay (I checked by running prediction on the training data).
@muellerzr @molbap @hadariru I think this happens because Trainer accepts the case where loss is None.
When loss is None and you want to compute metrics, the losses variable is never defined, because gathering None across multiple GPUs is a no-op. So you cannot del the losses variable, since it was never assigned.
I think there are two ways to make this work
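To illustrate the failure mode described above, here is a hypothetical, simplified sketch of the pattern (not the actual Trainer code): a variable that is only conditionally bound cannot later be deleted unconditionally.

```python
def evaluation_loop_sketch(loss):
    # `losses` is only bound when prediction_step returned a loss;
    # with label_names empty, loss comes back as None and the later
    # `del losses` raises UnboundLocalError
    if loss is not None:
        losses = [loss]
    try:
        del losses
    except UnboundLocalError:
        return "losses is not defined"
    return "ok"
```

With a real loss the deletion succeeds; with loss=None it hits the UnboundLocalError that surfaces as the reported failure.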
@SangbumChoi this is the model that I used
model = PaliGemmaForConditionalGeneration.from_pretrained(
    object_detection_config.MODEL_ID,
    torch_dtype=object_detection_config.MODEL_DTYPE,
    device_map=device,
    revision=object_detection_config.MODEL_REVISION,
)
I tried to backtrack the reason why loss is None.
I found that, during evaluation, self.label_names is [] and loss_without_labels is False.
I am not sure what value to give label_names, or how to set it on the Trainer.
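For context, Trainer derives its default label_names by inspecting the model's forward signature (roughly what transformers.utils.find_labels does). A simplified sketch of that mechanism, with a hypothetical forward:

```python
import inspect

def find_label_names(forward_fn):
    # any forward() parameter whose name contains "label" is treated
    # as a label key; if none match, label_names ends up empty and
    # prediction_step never computes a loss
    sig = inspect.signature(forward_fn)
    return [name for name in sig.parameters if "label" in name]

# hypothetical forward signature for illustration
def forward(input_ids=None, pixel_values=None, labels=None):
    pass
```

Here find_label_names(forward) yields ["labels"]; if a wrapper hides the labels parameter, the list comes back empty.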
Changing
data_collator = partial(self.data_collator, train=False)
to
data_collator = partial(self.data_collator, train=True)
in get_eval_dataloader gives me this error:
Traceback (most recent call last):
File "xxx", line 361, in <module>
trainer.train()
File "xxxlib/python3.11/site-packages/transformers/trainer.py", line 1885, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer.py", line 2291, in _inner_training_loop
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
File "xxxlib/python3.11/site-packages/transformers/trainer.py", line 2721, in _maybe_log_save_evaluate
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer.py", line 3572, in evaluate
output = eval_loop(
^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer.py", line 3780, in evaluation_loop
all_preds.add(logits)
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 326, in add
self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 138, in nested_concat
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 138, in <genexpr>
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 138, in nested_concat
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 138, in <genexpr>
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 138, in nested_concat
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 138, in <genexpr>
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 140, in nested_concat
return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "xxxlib/python3.11/site-packages/transformers/trainer_pt_utils.py", line 99, in torch_pad_and_concatenate
return torch.cat((tensor1, tensor2), dim=0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 454 but got size 482 for tensor number 1 in the list.
0%| | 10/24240 [00:27<18:35:03, 2.76s/it]
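The size mismatch above (454 vs. 482) comes from accumulating logits whose sequence lengths vary across batches; nested_concat pads before concatenating, and the cat fails when that padding cannot reconcile the shapes. A pure-Python sketch of the padding-then-concat idea (not the transformers implementation, which operates on torch tensors):

```python
def pad_and_concat(batch1, batch2, padding_index=-100):
    # pad every sequence to the longest length seen in either batch,
    # then concatenate along the batch dimension
    max_len = max(len(seq) for seq in batch1 + batch2)
    pad = lambda seq: seq + [padding_index] * (max_len - len(seq))
    return [pad(seq) for seq in batch1] + [pad(seq) for seq in batch2]
```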
@hadariru
I found out that self.label_names and loss_without_labels when it is evaluating is [] and False
Usually label_names for bounding boxes should be 'labels', but it depends on your dataset.
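One way to set this explicitly (assuming the collator emits its labels under the key "labels") is via TrainingArguments; the output_dir here is a hypothetical placeholder:

```python
from transformers import TrainingArguments

# label_names tells Trainer which batch keys to treat as labels, so
# prediction_step computes a loss during evaluation; adjust the key
# to whatever your collator actually emits
training_args = TrainingArguments(
    output_dir="paligemma-finetune",  # hypothetical path
    label_names=["labels"],
)
```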
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
transformers version: 4.41.2

Who can help?
@muellerzr @SunMarc @amyeroberts

Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Adding these 3 extra arguments in step 4 causes evaluation to be performed.
eval_dataloader is overridden with this, and the data_collator is adapted from step 2.
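A minimal sketch of such a collator (hypothetical names; the real one comes from step 2 of the tutorial), showing why a train flag controls whether evaluation batches carry labels at all:

```python
from functools import partial

def data_collator(examples, train=True):
    # hypothetical collator: labels are only added to the batch when
    # train=True, so an eval dataloader built with train=False yields
    # label-free batches and prediction_step returns loss=None
    batch = {"input_ids": [ex["input_ids"] for ex in examples]}
    if train:
        batch["labels"] = [ex["labels"] for ex in examples]
    return batch

eval_collator = partial(data_collator, train=False)
```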
Error logs:
Expected behavior
No error on evaluation (losses should exist, I think)