ku21fan / STR-Fewer-Labels

Scene Text Recognition (STR) methods trained with fewer real labels (CVPR 2021)
MIT License

Predict after training on custom dataset #10

Closed: lyminhuit closed this issue 1 year ago

lyminhuit commented 1 year ago

Hi, I trained a model on a custom Vietnamese dataset with this command:

!CUDA_VISIBLE_DEVICES=0 python train.py \
                --select_data / \
                --model_name TRBA \
                --exp_name CRNN_aug \
                --Aug Blur5-Crop99  \
                --train_data train \
                --valid_data val \
                --character 'aàảãáạăằẳẵắặâầẩẫấậbcdđeèẻẽéẹêềểễếệfghiìỉĩíịjklmnoòỏõóọôồổỗốộơờởỡớợpqrstuùủũúụưừửữứựvwxyỳỷỹýỵz0123456789' \
                --batch_ratio 1 

After 50k iterations, I ran prediction with best_score.pth using this command:

!python demo.py \
        --model_name TRBA \
        --image_folder /content/drive/MyDrive/UIT_CHALLENGE_2022/result/crop_img \
        --saved_model /content/drive/MyDrive/UIT_CHALLENGE_2022/src/STR-Fewer-Labels/best_score_new_reg.pth

The error:

of tokens and characters: 99

model input parameters 32 100 20 3 512 256 99 25 TPS ResNet BiLSTM Attn
loading pretrained model from /content/drive/MyDrive/UIT_CHALLENGE_2022/src/STR-Fewer-Labels/best_score_new_reg.pth
Traceback (most recent call last):
  File "demo.py", line 194, in
    demo(opt)
  File "demo.py", line 47, in demo
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1667, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
    size mismatch for module.Prediction.generator.weight: copying a param with shape torch.Size([108, 256]) from checkpoint, the shape in current model is torch.Size([99, 256]).
    size mismatch for module.Prediction.generator.bias: copying a param with shape torch.Size([108]) from checkpoint, the shape in current model is torch.Size([99]).
    size mismatch for module.Prediction.char_embeddings.weight: copying a param with shape torch.Size([108, 256]) from checkpoint, the shape in current model is torch.Size([99, 256]).
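The 108 vs 99 mismatch can be reconciled with a little arithmetic. This is a sketch under assumptions: it supposes the prediction layer has one row per character plus a fixed number of special tokens. That offset (5) is only inferred here from 108 - 103, not read from the repository. The guess that demo.py's default --character is the 94 printable ASCII characters is likewise an assumption that merely happens to fit the 99 in the error.

```python
import string
import unicodedata

# Character set passed to --character during training (copied from the command above).
# NFC normalization keeps each accented Vietnamese letter as a single code point.
vi_chars = unicodedata.normalize(
    "NFC",
    "aàảãáạăằẳẵắặâầẩẫấậbcdđeèẻẽéẹêềểễếệfghiìỉĩíịjklmnoòỏõóọôồổỗốộơờởỡớợ"
    "pqrstuùủũúụưừửữứựvwxyỳỷỹýỵz0123456789",
)

CKPT_ROWS = 108   # rows of Prediction.generator.weight in the checkpoint (from the error)
MODEL_ROWS = 99   # rows in the model that demo.py built (from the error)

# Inferred, not confirmed: special tokens added on top of the characters.
num_special = CKPT_ROWS - len(vi_chars)

# Assumption: demo.py's default --character is the 94 printable ASCII characters (no whitespace).
default_guess = string.digits + string.ascii_letters + string.punctuation

print(len(vi_chars), num_special, len(default_guess) + num_special)  # → 103 5 99
```

If those assumptions hold, the checkpoint was built from 103 characters while demo.py, run without --character, fell back to a 94-character default. That would explain both sizes in the error.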

Please help, thanks!

ku21fan commented 1 year ago

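If you trained the model with the option --character 'aàảãáạăằẳẵắặâầẩẫấậbcdđeèẻẽéẹêềểễếệfghiìỉĩíịjklmnoòỏõóọôồổỗốộơờởỡớợpqrstuùủũúụưừửữứựvwxyỳỷỹýỵz0123456789', you also need to pass the same option when predicting. The size of the prediction layer depends on the character set, so without it demo.py builds a model with 99 classes instead of the 108 stored in your checkpoint.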

So, try this command:

!python demo.py \
        --model_name TRBA \
        --character 'aàảãáạăằẳẵắặâầẩẫấậbcdđeèẻẽéẹêềểễếệfghiìỉĩíịjklmnoòỏõóọôồổỗốộơờởỡớợpqrstuùủũúụưừửữứựvwxyỳỷỹýỵz0123456789' \
        --image_folder /content/drive/MyDrive/UIT_CHALLENGE_2022/result/crop_img \
        --saved_model /content/drive/MyDrive/UIT_CHALLENGE_2022/src/STR-Fewer-Labels/best_score_new_reg.pth
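To catch this before load_state_dict fails, one option is to read the class count directly from the checkpoint. The helpers below are hypothetical and not part of STR-Fewer-Labels. They assume only the parameter name shown in the error (module.Prediction.generator.weight), whose first dimension is the number of classes. They also reuse the offset of 5 special tokens inferred from this issue.

```python
# Parameter name copied from the error message; the "module." prefix
# appears because the model was wrapped in torch.nn.DataParallel.
GENERATOR_KEY = "module.Prediction.generator.weight"

# Inferred from this issue (108 rows for 103 characters), not taken from the repo.
NUM_SPECIAL_TOKENS = 5


def checkpoint_num_class(state_dict):
    """Number of output classes stored in the checkpoint (rows of the generator weight)."""
    return state_dict[GENERATOR_KEY].shape[0]


def check_character(state_dict, character, num_special=NUM_SPECIAL_TOKENS):
    """Hypothetical pre-flight check: raise if `character` would build a different-sized head."""
    expected = checkpoint_num_class(state_dict)
    actual = len(character) + num_special
    if actual != expected:
        raise ValueError(
            f"--character yields {actual} classes, but the checkpoint has {expected}; "
            "pass the same --character string that was used for training"
        )
    return expected
```

In demo.py this could be called as `check_character(torch.load(opt.saved_model, map_location="cpu"), opt.character)` before `model.load_state_dict(...)`. With the default character set it would then report the 99 vs 108 mismatch in plain words instead of as a shape error.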
lyminhuit commented 1 year ago

It works, thank you, sir!