GabrielDornelles / pytorch-ocr

Simple Pytorch framework to train OCRs. Supports CRNNs, Attention, CTC and Cross Entropy Loss.
MIT License

Only +++ recognized #4

Closed Dddddebil closed 1 year ago

Dddddebil commented 1 year ago
      import torch
      import albumentations
      import numpy as np
      from PIL import Image 

      from models.crnn import CRNN
      from utils.model_decoders import decode_predictions, decode_padded_predictions

      classes = ['+', ',', '-', '.', '1', '2', '4', '5', '7', '9']

      def inference(image_path):
          image = Image.open(image_path).convert("RGB")
          image = image.resize((530, 70), resample=Image.BILINEAR)
          image = np.array(image)

          mean = (0.485, 0.456, 0.406)
          std = (0.229, 0.224, 0.225)
          aug = albumentations.Compose([
                  albumentations.Normalize(mean, std, max_pixel_value=255.0, always_apply=True)
              ])

          image = aug(image=image)["image"]
          image = np.transpose(image, (2, 0, 1)).astype(np.float32)
          image = image[None,...] 
          image = torch.from_numpy(image)
          if str(device) == "cuda": image = image.cuda()
          image = image.float()
          with torch.no_grad():
              preds, _ = model(image)

          if model.use_ctc:
              answer = decode_predictions(preds, classes)
          else:
              answer = decode_padded_predictions(preds, classes)
          return answer

      if __name__ == "__main__":
          model = CRNN(dims=256,
              num_chars=10, 
              use_attention=True,
              use_ctc=True,
              grayscale=True,
          )
          device = torch.device("cuda")
          model.to(device)
          model.load_state_dict(torch.load("./logs/crnn.pth"))
          model.eval()

          filepath = "dataset/516874.jpeg"
          answer = inference(filepath)
          print(f"text: {answer}")

Image size: torch.Size([1, 1, 70, 530]) text: ['+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+', '+']

GabrielDornelles commented 1 year ago

Sorry, I don't understand. Can you explain the problem?

Dddddebil commented 1 year ago

I trained the model on captchas with 9 characters. During training it showed 99% accuracy with no recognition errors, but when I try to recognize a picture, the output is nothing but '+' characters.

GabrielDornelles commented 1 year ago

I need you to provide me some answers:

Also, the first class in your list of classes should be the CTC blank token, which is denoted by "∅" in training. When you train the model, it prints the list of classes to the terminal. I recommend copying that list and using it for inference.
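To illustrate why the position of the blank token matters, here is a minimal sketch. It assumes, as described above, that index 0 is the CTC blank "∅". The `labels_for` helper and the `indices` list are hypothetical and exist only for this example; they are not part of the repository.

```python
BLANK = "∅"  # CTC blank token, as denoted in training

# The character set from the inference script above.
chars = ['+', ',', '-', '.', '1', '2', '4', '5', '7', '9']

classes_without_blank = chars          # what the script used
classes_with_blank = [BLANK] + chars   # blank token first

def labels_for(indices, classes):
    """Map per-stripe class indices to labels (hypothetical helper)."""
    return [classes[i] for i in indices]

# Hypothetical per-stripe indices; 0 is assumed to be the blank.
indices = [0, 0, 9, 9, 0, 5, 0]

wrong = labels_for(indices, classes_without_blank)
right = labels_for(indices, classes_with_blank)

print(wrong)  # every blank shows up as '+', other labels are shifted by one
print(right)  # blanks show up as '∅', digits map to the intended labels
```

With the blank missing from the list, every stripe that the model classifies as blank is reported as '+', and every other label is off by one position.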

Dddddebil commented 1 year ago

thanks!

GabrielDornelles commented 1 year ago

Did you solve the issue?

Dddddebil commented 1 year ago

Yes, but do you know why the characters are repeated?

    python3 inference.py dataset/761354.jpeg
    text: ['7', '7', '6', '6', '1', '3', '5', '4']

GabrielDornelles commented 1 year ago

It could be either a decoding problem (although I don't think it is) or a training problem. The final predicted string goes through a decoding step that removes those duplicates, so if they still appear, it means the model truly thinks there are two '7's and two '6's.

To improve this, you need either more training data or a longer training run. Try that and see if it solves your problem.

The model output is actually a long list of predictions: with CTC, the model predicts a class for each vertical stripe of the image. Normally each of your digits spans more than one stripe, so the model output contains duplicates. We also insert the blank token "∅" to denote blank space.

That may have been a confusing explanation, but the point is this: the model predicts duplicates for all your digits, but it also classifies empty space ("∅"). When a blank token sits between two runs of the same character, we assume that character really appears twice in sequence (say, 77∅77 would be '77', but 7777 would simply be '7' because there is no blank in between).
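The collapsing rule described above can be sketched in a few lines. This is a generic greedy CTC collapse written for illustration, not the repository's `decode_predictions`; the function name `ctc_collapse` is hypothetical.

```python
from itertools import groupby

BLANK = "∅"  # CTC blank token

def ctc_collapse(stripes, blank=BLANK):
    """Merge consecutive duplicates, then drop blank tokens."""
    # Step 1: collapse each run of identical symbols into one symbol.
    merged = [symbol for symbol, _ in groupby(stripes)]
    # Step 2: remove the blanks that separated the runs.
    return "".join(symbol for symbol in merged if symbol != blank)

print(ctc_collapse("77∅77"))  # two runs of '7' separated by a blank
print(ctc_collapse("7777"))   # a single run of '7'
```

The order of the two steps matters: if blanks were removed first, 77∅77 would turn into 7777 and then collapse to a single '7'.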

Dddddebil commented 1 year ago

Okay, I'll increase the training data from 3500 to 7000 images and train for 100 epochs instead of 50.

GabrielDornelles commented 1 year ago

Train for enough epochs to reach a good validation accuracy (at some point the accuracy stops improving, and then you can stop). As for the number of images, the more the better.
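One way to decide when the accuracy has stopped improving is a simple patience check on the validation accuracy history. The sketch below is only an illustration: `should_stop`, `patience`, and `min_delta` are hypothetical names, and nothing here is taken from the repository's training loop.

```python
def should_stop(val_acc_history, patience=5, min_delta=0.0):
    """Return True if the last `patience` epochs brought no improvement.

    val_acc_history: validation accuracy per epoch, oldest first.
    min_delta: smallest increase that still counts as an improvement.
    """
    if len(val_acc_history) <= patience:
        return False  # not enough epochs yet to judge
    best_before = max(val_acc_history[:-patience])
    recent_best = max(val_acc_history[-patience:])
    return recent_best <= best_before + min_delta

# Accuracy plateaued at 0.9 for the last three epochs.
print(should_stop([0.5, 0.7, 0.9, 0.9, 0.9, 0.9], patience=3))
# Accuracy is still climbing.
print(should_stop([0.5, 0.6, 0.7, 0.8, 0.9], patience=3))
```

Calling this once per epoch with the accumulated validation accuracies gives a stopping signal without fixing the number of epochs in advance.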

Dddddebil commented 1 year ago
    size mismatch for linear.weight: copying a param with shape torch.Size([256, 960]) from checkpoint, the shape in current model is torch.Size([256, 832]).

Why? I haven't changed anything since training, and I'm running the same code I trained with.

Dddddebil commented 1 year ago

If I change nn.Linear to 960, I get:

    RuntimeError: mat1 and mat2 shapes cannot be multiplied (133x1152 and 960x256)

WTF

GabrielDornelles commented 1 year ago

The input size of this linear layer is calculated dynamically from your input size. It looks like your input shape changed: as the error says, the saved weights contain a linear layer of torch.Size([256, 960]), but the model is now computing torch.Size([256, 832]) from your input image.
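As a rough illustration of how the linear layer's input size can depend on image height, here is a sketch under stated assumptions: a convolutional stack that ends with 64 channels and halves the height twice. Those two numbers, the helper name `linear_in_features`, and the heights used below are all hypothetical. They were picked only so the arithmetic reproduces the sizes from the error messages, and they are not read from the repository or from this thread.

```python
def linear_in_features(image_height, channels=64, downsample=4):
    """Flattened feature size per time step (hypothetical architecture).

    Assumes the conv stack outputs `channels` feature maps and reduces
    the image height by a factor of `downsample`.
    """
    return channels * (image_height // downsample)

# Illustrative heights only, chosen to match the sizes in the errors.
print(linear_in_features(60))  # 960
print(linear_in_features(52))  # 832
print(linear_in_features(72))  # 1152
```

Under these assumptions, a change of a few pixels in the configured or actual image height is enough to make the checkpoint, the constructed model, and the inference input disagree, so the resize dimensions used at inference should match the ones used in training.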
