Closed: giuseppesalvi closed this issue 1 year ago.
Facing the same issue. Any resolution around?
@giuseppesalvi
try the following function. Apologies for the code formatting.
```python
import re

import numpy as np
import torch
import pytorch_lightning as pl
from nltk import edit_distance


class Pix2Struct(pl.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        encoding, _ = batch
        outputs = self.model(**encoding)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        encoding, answers = batch
        flattened_patches, attention_mask = encoding["flattened_patches"], encoding["attention_mask"]
        batch_size = flattened_patches.shape[0]

        # we feed the prompt to the model
        decoder_input_ids = torch.full(
            (batch_size, 1),
            self.model.config.text_config.decoder_start_token_id,
            device=self.device,
        )
        outputs = self.model.generate(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            max_length=512,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            eos_token_id=self.processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            # seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = []
        for pred, answer in zip(predictions, answers):
            # pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)  # Donut-style cleanup, not needed here
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))

        return scores

    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader
```
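For completeness, here is a minimal sketch of how this module could be wired up and trained. The checkpoint name, hyperparameters, and Trainer arguments below are assumptions based on the CORD fine-tuning setup, not values taken from this thread; `train_dataloader` and `val_dataloader` are expected to already exist in the notebook, since the class above returns them directly.

```python
import pytorch_lightning as pl
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Hypothetical checkpoint and hyperparameters -- adjust to your own setup.
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
config = {"lr": 1e-5, "verbose": True}

# train_dataloader and val_dataloader must already be defined in the notebook's
# global scope, because the LightningModule above returns them as-is.
pl_module = Pix2Struct(config, processor, model)

trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1, gradient_clip_val=1.0)
trainer.fit(pl_module)
```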
@khadkechetan
I tried your code, and everything worked fine. However, I started to have doubts about the scheduler and optimizer used in the original notebook.
To investigate further, I ran the original code with a different optimizer, and it worked without any issues. This suggests the problem is specific to the Adafactor optimizer.
I changed only this line in the original notebook:
```python
# optimizer = Adafactor(self.parameters(), scale_parameter=False, relative_step=False, lr=self.config.get("lr"), weight_decay=1e-05)
optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
```
Thanks! Actually, you might be better off just using AdamW or Adam.
I'll update the notebook when I have the time.
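As a rough sketch of what that swap could look like in `configure_optimizers` (this AdamW setup and its weight decay value are assumptions, not the updated notebook):

```python
import torch

def configure_optimizers(self):
    # Plain AdamW in place of Adafactor; no scheduler, as in the original method.
    # This would replace configure_optimizers in the LightningModule above.
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"), weight_decay=0.01)
    return optimizer
```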
I tried to run your notebook for Pix2Struct fine-tuning on the CORD dataset on Google Colab and got the following error during the trainer.fit(pl_module) call.
Error:
When I tried the same notebook a couple of weeks ago, everything worked fine. Did something change? Does it have something to do with PyTorch 2.0, which is now the default version in Colab?
Thanks.