hezarai / hezar

The all-in-one AI library for Persian, supporting a wide variety of tasks and modalities!
https://hezarai.github.io/hezar/
Apache License 2.0
843 stars, 45 forks

ocr training #134

Closed mehrdadgh93 closed 8 months ago

mehrdadgh93 commented 9 months ago

Hi @arxyzan, thanks for your amazing work! I need to train an OCR model on data I generated myself, but I cannot find your OCR training code. Could you please share it?

arxyzan commented 9 months ago

Hi @mehrdadgh93, thanks for your feedback! Since you have your own data, you will probably need to write a custom Dataset class, either in plain PyTorch or in Hezar. I recommend writing a Hezar-compatible dataset class, which is a wrapped PyTorch Dataset class with extra features that make training with Hezar's Trainer easier. The closest example as of now is our license plate OCR training script at https://github.com/hezarai/hezar/blob/main/examples/train/train_ocr_alpr_example.py (keep in mind that this script assumes your dataset is hosted on the Hub and is compatible with Hezar).

I'll write a full notebook for training a custom OCR model on custom datasets this weekend.

arxyzan commented 9 months ago

Hello @mehrdadgh93, just wanted to let you know that I pushed a preview of our original ParsynthOCR-4M dataset, named ParsynthOCR-200K. It has 200K samples, in contrast to the original 4 million samples, which we haven't published yet (due to size). You can check this dataset here: https://huggingface.co/datasets/hezarai/parsynth-ocr-200k

I pushed this dataset as a preview to showcase a training example on this data using the CRNN model.

The example file is this: https://github.com/hezarai/hezar/blob/main/examples/train/train_ocr.py

This script currently works only on the main branch; it will be available in the next stable release of Hezar.

If you want to try it out, just make sure you install Hezar from source like below:

git clone https://github.com/hezarai/hezar.git
cd hezar
pip install ".[vision]"

For your problem, you have two options:

  1. Upload your dataset to the Hub, following the format of the Hezar dataset linked above
  2. Create a custom PyTorch dataset class; the rest should work fine.

If you need additional help with this, just let me know in this issue.

mehrdadgh93 commented 9 months ago

Thanks @arxyzan for your response! I tried to use your training code but got this error:

from datasets import load_dataset
ModuleNotFoundError: No module named 'datasets'

I think some parts of the code in the "hezar/data/dataset/" directory are missing!

arxyzan commented 9 months ago

@mehrdadgh93 You're welcome. How did you install Hezar? This error means the datasets package is not installed, although it should be installed out of the box along with Hezar.
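A quick way to confirm which packages are missing from the active environment is sketched below. It uses only the standard library. The list of module names checked is an assumption based on what this thread mentions, not Hezar's actual dependency list.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed list for illustration: modules this thread refers to.
print(missing_modules(["datasets", "torch", "pandas"]))
```

An empty list means all of the listed modules are importable; anything printed points to a package that was not installed into the interpreter running the script.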

arxyzan commented 9 months ago

@mehrdadgh93 I'm glad it's solved. If you're trying to implement a torch Dataset, I highly recommend also looking into the code in hezar.data.datasets.ocr_dataset, which also implements a torch dataset class for OCR. I think the only part you actually have to reimplement is the _load() method. In this method you just need to provide a dataframe of your data with the columns image_path and label. A sample to work on would be like below:

import pandas as pd

from hezar.data import OCRDataset, OCRDatasetConfig

class PersianOCRDataset(OCRDataset):
    def __init__(self, config: OCRDatasetConfig, split=None, **kwargs):
        super().__init__(config=config, split=split, **kwargs)

    def _load(self, split=None):
        # Load a dataframe here and make sure the split is fetched
        data = pd.read_csv(self.config.path)
        # preprocess if needed
        return data

    def __getitem__(self, index):
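        # .values is a property, not a method, so select the two columns by name
        row = self.data.iloc[index]
        path, text = row["image_path"], row["label"]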
        pixel_values = self.image_processor(path, return_tensors="pt")["pixel_values"][0]
        labels = self._text_to_tensor(text)
        inputs = {
            "pixel_values": pixel_values,
            "labels": labels,
        }
        return inputs

dataset_config = OCRDatasetConfig(
    path="path/to/csv",
    text_split_type="char_split",
    text_column="label",
    images_paths_column="image_path",
    reverse_digits=True,
)

train_dataset = PersianOCRDataset(dataset_config, split="train")
eval_dataset = PersianOCRDataset(dataset_config, split="test")
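The sample above reads a CSV with the columns image_path and label. A minimal sketch for producing such files from generated data follows. It assumes the samples are available as (image path, text) pairs and that train and test are stored as two separate CSV files; the helper names `split_samples` and `write_ocr_csv` are hypothetical and not part of Hezar.

```python
import csv
import random


def split_samples(samples, test_ratio=0.1, seed=0):
    """Shuffle deterministically and return (train, test) lists of samples."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_ratio))
    return shuffled[n_test:], shuffled[:n_test]


def write_ocr_csv(samples, csv_path):
    """Write (image_path, label) pairs to a CSV with the two expected columns."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["image_path", "label"])
        writer.writerows(samples)


# Hypothetical usage with generated samples:
# train, test = split_samples(samples, test_ratio=0.1)
# write_ocr_csv(train, "train.csv")
# write_ocr_csv(test, "test.csv")
```

With two files, `_load()` would have to pick the file that matches the requested split, or each dataset would need its own config pointing at the right path.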

mehrdadgh93 commented 9 months ago

Thanks @arxyzan jan! When I ran your code without changing any config, I got another error:

logits = outputs["logits"].detach().cpu().numpy()
AttributeError: 'tuple' object has no attribute 'detach'

So I changed the code to:

logits = outputs["logits"][0].detach().cpu().numpy()

If changing the code as mentioned above leads to another error, it's possible that the model outputs multiple tensors, and you need to handle them appropriately. You may need to check the structure of the outputs tuple and select the correct tensor for further processing.
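One way to check the structure of such an output before indexing into it is sketched below. Both helpers are hypothetical and written in plain Python, so they work on any nested tuple or list regardless of what the elements are; they are not part of Hezar or PyTorch.

```python
def describe(value):
    """Return a readable summary of a possibly nested output structure."""
    if isinstance(value, (tuple, list)):
        inner = ", ".join(describe(v) for v in value)
        return f"{type(value).__name__}[{inner}]"
    # Tensors and arrays expose .shape; other objects are shown by type only.
    shape = getattr(value, "shape", None)
    suffix = str(tuple(shape)) if shape is not None else ""
    return f"{type(value).__name__}{suffix}"


def first_element(value):
    """Descend into nested tuples/lists and return the first leaf element."""
    while isinstance(value, (tuple, list)):
        if not value:
            raise ValueError("empty output sequence")
        value = value[0]
    return value


# Hypothetical usage inside the evaluation code:
# print(describe(outputs["logits"]))
# logits = first_element(outputs["logits"]).detach().cpu().numpy()
```

Printing the summary first shows how many tensors the tuple holds, which helps decide whether taking the first element is actually the right choice.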

arxyzan commented 9 months ago

Oh I see, this error originally occurred last week and was fixed immediately. I think you need to install the latest version of Hezar to fix this.

git clone https://github.com/hezarai/hezar.git
cd hezar
pip install ".[all]"

Or if you have cloned the repo already, you can just pull the latest changes and reinstall.

git pull origin main
pip install ".[all]"