hezarai / hezar

The all-in-one AI library for Persian, supporting a wide variety of tasks and modalities!
https://hezarai.github.io/hezar/
Apache License 2.0
823 stars, 45 forks

TrOCR training #145

Closed (aliesmaeilpour closed this issue 5 months ago)

aliesmaeilpour commented 6 months ago

Hi, I want to use your base OCR model to train a new model on my custom dataset, but your OCR tutorial is not working. Is there any way to access your OCR model through Hezar, or to train the model with Transformers? I would appreciate your reply.

arxyzan commented 6 months ago

Hi @aliesmaeilpour jan, the TrOCR model did not meet our performance requirements and we got better results with CRNN, so we didn't work on it further. If you want to use our trained TrOCR model with Transformers, you can save it as a Transformers model as shown below; the rest is just standard Transformers usage.

from hezar.models import Model
from transformers import VisionEncoderDecoderModel

# Save the model in a Transformers-compatible format
model = Model.load("hezarai/trocr-base-fa-v2")
model.trocr.save_pretrained("trocr-base-fa")

# Load Transformers model
transformers_trocr = VisionEncoderDecoderModel.from_pretrained("trocr-base-fa")
# Now you can finetune this model using Transformers Trainer
...

Now that you have saved the Transformers TrOCR model at trocr-base-fa you can follow this tutorial to finetune it on your own dataset: https://github.com/NielsRogge/Transformers-Tutorials/tree/master/TrOCR

aliesmaeilpour commented 6 months ago

Thank you, wish you the best.

aliesmaeilpour commented 6 months ago

I have another question, actually: can we fine-tune your CRNN model on our own dataset with Hezar?

arxyzan commented 5 months ago

Hi @aliesmaeilpour jan, sorry for the late response! Yes, you can finetune the CRNN model on your own dataset using Hezar. In a nutshell, all you need to do is implement your own dataset class by subclassing Hezar's OCRDataset.
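
For context, the dataset class in the sample that follows reads a CSV from the configured path and is expected to return only the rows of the requested split. Here is a minimal standard-library sketch of that idea. It assumes a hypothetical `split` column next to the `image_path` and `label` columns named in the config; the `split` column, the sample rows, and the helper `load_split` are illustrative only and are not part of Hezar.

```python
import csv
import io

# Hypothetical CSV contents; the `split` column is an assumption for illustration.
CSV_TEXT = """image_path,label,split
images/0001.jpg,سلام,train
images/0002.jpg,۱۲۳۴,train
images/0003.jpg,تهران,test
"""


def load_split(csv_file, split=None):
    """Return (image_path, label) pairs, optionally restricted to one split."""
    reader = csv.DictReader(csv_file)
    return [
        (row["image_path"], row["label"])
        for row in reader
        if split is None or row["split"] == split
    ]


# In practice csv_file would be an open file; StringIO keeps the sketch self-contained.
train_rows = load_split(io.StringIO(CSV_TEXT), split="train")
test_rows = load_split(io.StringIO(CSV_TEXT), split="test")
print(len(train_rows), len(test_rows))
```

The same filtering could be done inside `_load` on a dataframe; keeping separate CSV files per split would work just as well.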

A really short sample would have a structure like the one below:

import pandas as pd

from hezar.data import OCRDataset, OCRDatasetConfig
from hezar.models import CRNNImage2TextConfig, CRNNImage2Text
from hezar.preprocessors import ImageProcessor
from hezar.trainer import Trainer, TrainerConfig

class PersianOCRDataset(OCRDataset):
    def __init__(self, config: OCRDatasetConfig, split=None, **kwargs):
        super().__init__(config=config, split=split, **kwargs)

    def _load(self, split=None):
        # Load a dataframe here and make sure the split is fetched
        data = pd.read_csv(self.config.path)
        # preprocess if needed
        return data

    def __getitem__(self, index):
        # `.values` is a property, not a method; this assumes exactly two columns (image path, then text)
        path, text = self.data.iloc[index].values
        pixel_values = self.image_processor(path, return_tensors="pt")["pixel_values"][0]
        labels = self._text_to_tensor(text)
        inputs = {
            "pixel_values": pixel_values,
            "labels": labels,
        }
        return inputs

dataset_config = OCRDatasetConfig(
    path="path/to/csv",
    text_split_type="char_split",
    text_column="label",
    images_paths_column="image_path",
    reverse_digits=True,
)

train_dataset = PersianOCRDataset(dataset_config, split="train")
eval_dataset = PersianOCRDataset(dataset_config, split="test")

model = CRNNImage2Text(
    CRNNImage2TextConfig(
        id2label=train_dataset.config.id2label,
        map2seq_in_dim=1024,
        map2seq_out_dim=96
    )
)
preprocessor = ImageProcessor(train_dataset.config.image_processor_config)

train_config = TrainerConfig(
    output_dir="crnn-plate-fa-v1",
    task="image2text",
    device="cuda",
    batch_size=8,
    num_epochs=20,
    metrics=["cer"],
    metric_for_best_model="cer"
)

trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
    preprocessor=preprocessor,
)
trainer.train()

Note that this code does not work out of the box and you will have to make your own changes. Let me know if you run into any other challenges.
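
As a side note on the `cer` metric selected in the trainer config: character error rate is commonly defined as the total character-level edit distance between predictions and references, divided by the total number of reference characters, so lower is better. The sketch below shows that common definition in plain Python. The function names are illustrative, and Hezar's own `cer` implementation may differ in details such as text normalization.

```python
def edit_distance(reference, hypothesis):
    """Levenshtein distance between two strings, computed row by row."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_char in enumerate(reference, start=1):
        current = [i]
        for j, hyp_char in enumerate(hypothesis, start=1):
            substitution = previous[j - 1] + (ref_char != hyp_char)
            insertion = current[j - 1] + 1
            deletion = previous[j] + 1
            current.append(min(substitution, insertion, deletion))
        previous = current
    return previous[-1]


def cer(references, hypotheses):
    """Total edits divided by total reference characters (0.0 for empty references)."""
    total_edits = sum(edit_distance(r, h) for r, h in zip(references, hypotheses))
    total_chars = sum(len(r) for r in references)
    return total_edits / total_chars if total_chars else 0.0


# One wrong character out of three reference characters.
print(cer(["abc"], ["abd"]))
```

Running a check like this on a handful of predictions can help confirm that the labels and the model outputs use the same character set before a long training run.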

aliesmaeilpour commented 5 months ago

Thank you for your response.