NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License
9.56k stars 1.46k forks source link

Fine tune TrOCR on IAM Handwriting Database using Seq2SeqTrainer #412

Open johnlockejrr opened 7 months ago

johnlockejrr commented 7 months ago

Seems the IAM dataset is not public anymore, any other location?

Trying to download, output:

<Error>
<script id="argent-x-extension" data-extension-id="dlcobpjiigpikoobohmabehhmhfoodbb"/>
<Code>PublicAccessNotPermitted</Code>
<Message>Public access is not permitted on this storage account. RequestId:954279e9-d01e-0066-427a-91a772000000 Time:2024-04-18T10:26:49.5054310Z</Message>
</Error>
johnlockejrr commented 7 months ago

Managed to get the data from the origin, now another problem, I follow your example and I get no model saved... do I do anything wrong?

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_metric
from transformers import default_data_collator
from huggingface_hub import login

df = pd.read_fwf('./IAM/gt_test.txt', header=None)
df.rename(columns={0: "file_name", 1: "text"}, inplace=True)
print(df.head())

train_df, test_df = train_test_split(df, test_size=0.2)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

class IAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
train_dataset = IAMDataset(root_dir='./IAM/',
                           df=train_df,
                           processor=processor)
eval_dataset = IAMDataset(root_dir='./IAM/',
                           df=test_df,
                           processor=processor)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

encoding = train_dataset[0]
for k,v in encoding.items():
  print(k, v.shape)

#labels = encoding['labels']
#labels[labels == -100] = processor.tokenizer.pad_token_id
#label_str = processor.decode(labels, skip_special_tokens=True)
#print(label_str)

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=True,
    output_dir="./iam-train",
    overwrite_output_dir=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    push_to_hub=True,
    hub_token="hf_XXXXXXXXXXXXXXXXXXXXXXXXX",
    logging_steps=2,
    save_steps=1000,
    eval_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="cer",
)

cer_metric = load_metric("cer")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

trainer.train()
(TrOCR-py3.10) incognito@DESKTOP-NHKR7QL:~/TrOCR-py3.10$ ls -al iam-train/
total 8
drwxr-xr-x 2 incognito incognito 4096 Apr 18 13:34 .
drwxr-xr-x 9 incognito incognito 4096 Apr 18 14:58 ..
NielsRogge commented 7 months ago

If you provide the save_steps argument, then the model should be saved automatically to output_dir every save_steps (since save_strategy="steps" by default).

johnlockejrr commented 7 months ago

I did a:

processor.save_pretrained('./iam-train')
model.save_pretrained('./iam-train')

And it saved... the old model?

Anyway, I would like to save the best model. I think the save doesn’t care about the best model, so will just save every save_steps regardless of which step had the better loss. Am I wrong? Should I evaluate by epoch?

NielsRogge commented 7 months ago

https://discuss.huggingface.co/t/save-only-best-model-in-trainer/8442

johnlockejrr commented 7 months ago

Thank you so much!

Side question: do you have any scripts/docs how to train a foreign language TrOCR model? I mean here by foreign language especially Hebrew.

NielsRogge commented 7 months ago

Refer to this thread: https://github.com/huggingface/transformers/issues/18163

johnlockejrr commented 7 months ago

Sorry for disturbing, I'm a novice in BERT... untill now I worked only with Kraken OCR, still neural networks but a little different. Should I give a go to this code? I want to train Hebrew/Samaritan manuscripts recognition.

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import TrOCRProcessor
from transformers import VisionEncoderDecoderModel
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import load_metric
from transformers import default_data_collator

df = pd.read_fwf('./SAM/gt_test.txt', header=None)
df.rename(columns={0: "file_name", 1: "text"}, inplace=True)
print(df.head())

train_df, test_df = train_test_split(df, test_size=0.1)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

class SAMDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. resize + normalize)
        image = Image.open(self.root_dir + file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")

train_dataset = SAMDataset(root_dir='./SAM/',
                           df=train_df,
                           processor=processor)
eval_dataset = SAMDataset(root_dir='./SAM/',
                           df=test_df,
                           processor=processor)

print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))

encoding = train_dataset[0]
for k,v in encoding.items():
  print(k, v.shape)

#labels = encoding['labels']
#labels[labels == -100] = processor.tokenizer.pad_token_id
#label_str = processor.decode(labels, skip_special_tokens=True)
#print(label_str)

#model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")

encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "imvladikon/alephbertgimmel-base-512"
#
# load a fine-tuned image captioning model and corresponding tokenizer and image processor
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
   encoder_checkpoint, decoder_checkpoint
).to("cuda")

# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=True,
    output_dir=f"{encoder_checkpoint}-ft-sam-v1",
    overwrite_output_dir=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    push_to_hub=True,
    hub_token="hf_XXXXXXXXXXXXXXXXXXXXXXXX",
    logging_steps=2,
    save_steps=1000,
    eval_steps=200,
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="cer",
)

cer_metric = load_metric("cer")

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"cer": cer}

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

trainer.train()
trainer.save_model()
trainer.push_to_hub()
johnlockejrr commented 7 months ago

ValueError: Input image size (384*384) doesn't match model (224*224) Any idea where should I pass interpolate_pos_encoding=True? Or is there a non-Vit model that could work in my case? Thanks!

johnlockejrr commented 7 months ago

Finally I trained it on google/vit-base-patch16-384, after finishing it recognizes gibberish not even already trained images... in Hebrew, though, as I trained it, but gibberish...

NielsRogge commented 7 months ago

I'd recommend starting with 5 training examples and see if the model is able to overfit them

johnlockejrr commented 7 months ago

Ok, I'll do that! Should I keep google/vit-base-patch16-384 or use "google/vit-base-patch16-224-in21k" with interpolate_pos_encoding=True? In general, the script seems ok? Thank you so much!

johnlockejrr commented 7 months ago

Trained with 16 samples:

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
{'loss': 16.0417, 'grad_norm': nan, 'learning_rate': 2e-05, 'epoch': 1.0}
{'loss': 15.6389, 'grad_norm': nan, 'learning_rate': 2e-05, 'epoch': 2.0}
{'loss': 15.0141, 'grad_norm': 173.44268798828125, 'learning_rate': 1.8e-05, 'epoch': 3.0}
{'loss': 11.8002, 'grad_norm': 53.65483856201172, 'learning_rate': 1.6000000000000003e-05, 'epoch': 4.0}
{'loss': 10.464, 'grad_norm': 53.46924591064453, 'learning_rate': 1.4e-05, 'epoch': 5.0}
{'loss': 9.4223, 'grad_norm': 38.62109375, 'learning_rate': 1.2e-05, 'epoch': 6.0}
{'loss': 8.9751, 'grad_norm': 27.501571655273438, 'learning_rate': 1e-05, 'epoch': 7.0}
{'loss': 8.7579, 'grad_norm': 20.1580867767334, 'learning_rate': 8.000000000000001e-06, 'epoch': 8.0}
{'loss': 8.4512, 'grad_norm': 18.931493759155273, 'learning_rate': 6e-06, 'epoch': 9.0}
{'loss': 8.241, 'grad_norm': 20.91577911376953, 'learning_rate': 4.000000000000001e-06, 'epoch': 10.0}
{'train_runtime': 12.2611, 'train_samples_per_second': 11.418, 'train_steps_per_second': 1.631, 'train_loss': 11.280629634857178, 'epoch': 10.0}

Test output:

(huggingface-source-py3.10) incognito@DESKTOP-NHKR7QL:~/TrOCR-py3.10$ python test_ocr.py LINES/sam_gt/2.4.jpg
/home/incognito/huggingface-source-py3.10/lib/python3.10/site-packages/transformers/generation/utils.py:1252: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )
  warnings.warn(
 � � �� ��� � �ת� �ת ��ת���ת � �ב ��ב� �ב��ב � � �� �תת�תת �תב �ת� ���� ����

Even worse.

NielsRogge commented 7 months ago

That means there's a bug in data prepatation/hyperparameter settings/model configuration.

I recommend this guide for debugging: https://karpathy.github.io/2019/04/25/recipe/