Did you modify the default config? Or can you provide your config and describe what you modified?
@andyjpaddle I didn't modify the default config, but I used my own dataset.
Did you modify the default dict? Or could you provide your train config?
@andyjpaddle Sure, this is the configuration. You can ignore the training/evaluation dataset (it is just for testing purposes):
```yaml
Global:
  use_gpu: True
  epoch_num: 10
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/r45_abinet/
  save_epoch_step: 1
  eval_batch_step: [0, 2000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  character_dict_path: ppocr/utils/dict/oscar_japan.txt
  max_text_length: 150
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/predicts_abinet.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.99
  clip_norm: 20.0
  lr:
    name: Piecewise
    decay_epochs: [6]
    values: [0.0001, 0.00001]
  regularizer:
    name: 'L2'
    factor: 0.

Architecture:
  model_type: rec
  algorithm: ABINet
  in_channels: 3
  Transform:
  Backbone:
    name: ResNet45
  Head:
    name: ABINetHead
    use_lang: True
    iter_size: 3

Loss:
  name: CELoss
  ignore_index: &ignore_index 4500

PostProcess:
  name: ABINetLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: evaluation
    label_file_list:
      - evaluation/japanese-single-line-barna-modified-resized.txt
    ratio_list:
      - 1.0
    transforms:
      - DecodeImage:
          img_mode: RGB
          channel_first: False
      - ABINetRecAug:
      - ABINetLabelEncode:
          ignore_index: *ignore_index
      - ABINetRecResizeImg:
          image_shape: [3, 32, 128]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length']
  loader:
    shuffle: True
    batch_size_per_card: 48
    drop_last: True
    num_workers: 2
    use_shared_memory: True

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: evaluation
    label_file_list:
      - evaluation/japanese-single-line-barna-modified-resized.txt
    ratio_list:
      - 1.0
    transforms:
      - DecodeImage:
          img_mode: RGB
          channel_first: False
      - ABINetLabelEncode:
          ignore_index: *ignore_index
      - ABINetRecResizeImg:
          image_shape: [3, 32, 128]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 12
    num_workers: 2
    use_shared_memory: True
```
And here is the dictionary oscar_japan.txt; it has 4,400 characters in total.
The differences from the default config are probably just the Head section:

```yaml
Head:
  name: ABINetHead
  use_lang: True
  iter_size: 3
```
Please add a param `max_length: 150` to the Head; the relevant code is here: https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/ppocr/modeling/heads/rec_abinet_head.py#L183
ABINet resizes the image to 32*128 during training, so the most appropriate max_text_length is 25. If max_text_length is changed, it is recommended to change the resize parameter (ABINetRecResizeImg) at the same time.
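For reference, a minimal sketch of the geometry behind this note, assuming (per the `h=8  # H//4` / `w=32  # W//4` defaults quoted from `rec_abinet_head.py` later in this thread) that the backbone downsamples both H and W by 4; the helper name is hypothetical:

```python
# Hypothetical helper, not part of PaddleOCR: it only counts how many
# feature tokens the vision branch produces for a given input shape.
def abinet_feature_tokens(image_shape):
    c, h, w = image_shape       # [C, H, W]
    return (h // 4) * (w // 4)  # backbone downsamples H and W by 4

print(abinet_feature_tokens([3, 32, 128]))  # 256  -> the default 8 x 32 map
print(abinet_feature_tokens([3, 32, 768]))  # 1536 -> the X[0] in the errors below
```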
Thank you for the answer @andyjpaddle @Topdu. I added `max_length` to the Head, but I still get a dimension-mismatch error.
```yaml
Global:
  use_gpu: True
  epoch_num: 10
  log_smooth_window: 20
  print_batch_step: 50
  save_model_dir: ./output/rec/r45_abinet/
  save_epoch_step: 20
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path: ppocr/utils/dict/oscar_japan.txt
  max_text_length: &max_text_length 150
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/rec/predicts_abinet.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.99
  clip_norm: 20.0
  lr:
    name: Piecewise
    decay_epochs: [6]
    values: [0.0001, 0.00001]
  regularizer:
    name: 'L2'
    factor: 0.

Architecture:
  model_type: rec
  algorithm: ABINet
  in_channels: 3
  Transform:
  Backbone:
    name: ResNet45
  Head:
    name: ABINetHead
    use_lang: True
    iter_size: 3
    max_length: *max_text_length

Loss:
  name: CELoss
  ignore_index: &ignore_index 4440 # Must be greater than the number of character classes

PostProcess:
  name: ABINetLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    ...
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - ABINetRecAug:
      - ABINetLabelEncode: # Class handling label
          ignore_index: *ignore_index
      - ABINetRecResizeImg:
          image_shape: [3, 32, 768]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 24
    drop_last: True
    num_workers: 2

Eval:
  dataset:
    ...
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - ABINetLabelEncode: # Class handling label
          ignore_index: *ignore_index
      - ABINetRecResizeImg:
          image_shape: [3, 32, 768]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1
    num_workers: 2
    use_shared_memory: False
```
Here is the error:

```
ValueError: (InvalidArgument) Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [1536, 24, 512] and the shape of Y = [256, 1, 512]. Received [1536] in X is not equal to [256] in Y at i:0.
  [Hint: Expected x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1 == true, but received x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1:0 != true:1.] (at /paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:240)
  [operator < elementwise_add > error]
```
If I change the batch size to 12, the error is as follows:

```
ValueError: (InvalidArgument) Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [1536, 12, 512] and the shape of Y = [256, 1, 512]. Received [1536] in X is not equal to [256] in Y at i:0.
  [Hint: Expected x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1 == true, but received x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1:0 != true:1.] (at /paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:240)
  [operator < elementwise_add > error]
```
Similarly, if I change image_shape to something else, it outputs a different error, but still a dimension mismatch. Are there any hard requirements for the dataset, e.g., does the training data have to be 32*128? If possible, could you send me a sample of a working training dataset (either SimpleDataSet or LMDBDataSet)? Thank you.
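(For anyone else wondering about the format: SimpleDataSet reads a plain-text label file with one sample per line, the image path and its transcription separated by a tab. The file names below are hypothetical:)

```
evaluation/word_001.png	こんにちは
evaluation/word_002.png	日本語のテスト
```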
ppocr/modeling/heads/rec_abinet_head.py

```python
class PositionAttention(nn.Layer):
    def __init__(self,
                 max_length,
                 in_channels=512,
                 num_channels=64,
                 h=8,   # H//4
                 w=32,  # W//4
                 mode='nearest',
                 **kwargs):
```
Modify `h` and `w` to match the `H` and `W` of your `image_shape` (i.e., `H//4` and `W//4`).
Thank you @Topdu. With `image_shape: [3, 32, 768]`, I changed it as follows:
```python
class PositionAttention(nn.Layer):
    def __init__(self,
                 max_length,
                 in_channels=512,
                 num_channels=64,
                 h=8,    # 32/4
                 w=192,  # 768/4
                 mode='nearest',
                 **kwargs):
```
But I got the same error:

```
ValueError: (InvalidArgument) Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [1536, 12, 512] and the shape of Y = [256, 1, 512]. Received [1536] in X is not equal to [256] in Y at i:0.
  [Hint: Expected x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1 == true, but received x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1:0 != true:1.] (at /paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:240)
  [operator < elementwise_add > error]
```
The shapes are the same, X = [1536, 12, 512] and Y = [256, 1, 512], regardless of the change. Do you know where/how these tensors are calculated? Thank you.
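For what it's worth, this shape pattern, a [T, 1, D] tensor broadcast onto a [T, B, D] sequence, is characteristic of a positional-encoding addition. A minimal repro sketch, under the assumption (not confirmed in this thread) that the [256, 1, 512] tensor is a positional table sized for the default 8 x 32 = 256-token feature map somewhere outside PositionAttention:

```python
import paddle

feat = paddle.zeros([1536, 12, 512])  # 8 x 192 tokens from a 32 x 768 input
pos = paddle.zeros([256, 1, 512])     # table sized for the default 8 x 32 map
try:
    _ = feat + pos                    # broadcasting 1536 against 256 fails
except ValueError as err:
    print(err)                        # same message as in the logs above
```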
Update: for X = [1536, 12, 512], the 1536 seems to come from the width: 768 * 2 (the factor is always 2). So if I change the width to 128 (image_shape: [3, 32, 128]), I don't get the above error, since X[0] = 256 now matches Y[0] (Y is always [256, 1, 512] regardless of the configuration). However, I got a different error:

```
Traceback (most recent call last):
  File "tools/train.py", line 208, in <module>
    main(config, device, logger, vdl_writer)
  File "tools/train.py", line 183, in main
    amp_level, amp_custom_black_list)
  File "/app/tools/program.py", line 297, in train
    avg_loss.backward()
  File "/usr/local/lib/python3.7/dist-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/usr/local/lib/python3.7/dist-packages/paddle/fluid/wrapped_decorator.py", line 25, in __impl__
    return wrapped_func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/paddle/fluid/framework.py", line 229, in __impl__
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/paddle/fluid/dygraph/varbase_patch_methods.py", line 249, in backward
    framework._dygraph_tracer())
RuntimeError: (NotFound) No grad node corresponding to grad Tensor (auto_1442_@GRAD) was found.
  [Hint: Expected find_grad_node_of_var == true, but received find_grad_node_of_var:0 != true:1.] (at /paddle/paddle/fluid/imperative/basic_engine.cc:245)
```
The problem of inconsistent tensor shapes is common, and the most effective way to debug it is line by line.
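One way to do that line-by-line check without editing the model code is to attach a forward hook to every sublayer and print its output shape. A hypothetical helper (it assumes `model` is the recognizer built by `tools/train.py`; `trace_shapes` is not a PaddleOCR function):

```python
import paddle

def trace_shapes(model):
    """Print each sublayer's output shape to localize a mismatch."""
    def hook(layer, inputs, outputs):
        out = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        if isinstance(out, paddle.Tensor):
            print(f"{layer.__class__.__name__:>24s} -> {list(out.shape)}")
    for sub in model.sublayers():
        sub.register_forward_post_hook(hook)

# usage, e.g. in tools/train.py right after the model is built:
# trace_shapes(model)
```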
@Topdu I managed to make it work with an image width of 128. However, with a different width, I got another tensor mismatch:

```
ValueError: (InvalidArgument) Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [12, 64, 2, 26] and the shape of Y = [12, 64, 2, 25]. Received [26] in X is not equal to [25] in Y at i:3.
```

Perhaps you know something about this tensor X = [12, 64, 2, 26]. Which part should I look into? Thank you.
The fix is in `ppocr/modeling/heads/rec_abinet_head.py`. @yong-asial, have a try! @Topdu, can I submit a pull request?
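For anyone hitting the same 26-vs-25 mismatch: an off-by-one like this is typical of the mini U-Net inside PositionAttention, where a pooled feature map is upsampled by 2 and added back to an input whose width is odd. A minimal sketch of the effect (shapes taken from the error above; ceil-mode pooling is an assumption, not confirmed from the source):

```python
import paddle
import paddle.nn.functional as F

x = paddle.zeros([12, 64, 2, 25])  # feature map with odd width 25
down = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)  # [12, 64, 1, 13]
up = F.interpolate(down, scale_factor=2, mode='nearest')         # [12, 64, 2, 26]
print(up.shape)  # adding this back to x (width 25) raises the broadcast error
```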
> ABINet resizes the image to 32*128 during training, so the most appropriate max_text_length is 25. If max_text_length is changed, it is recommended to change the resize parameter (ABINetRecResizeImg) at the same time.

Hello sir, I got the same issue as described above. Can you provide information on how we should adjust max_text_length relative to ABINetRecResizeImg to avoid mismatch errors? I can get good accuracy with resize image shape 32*128 on my own dataset; however, the model only recognizes the first word in each image, and my dataset is phrase-level. I think this happens because max_text_length = 25, and if I set this parameter higher or lower, I can't train the model and I get the described error. Do you have any solution to these issues? Thanks.
`git checkout dygraph`
> git checkout dygraph

Hi brother, what do you mean by dygraph?
Please provide the following information to quickly locate the problem:
I tried to train with ABINet (rec_r45_abinet.yml) but I got the above or a similar error. When I tried changing the batch size (or image size), it displayed a different error, but still a dimension mismatch.
Do you have any guidelines for training with this algorithm (ABINet)? I got the same error with VisionLAN as well. I tried both SimpleDataSet and LMDBDataSet, but neither works.
Note: with the same dataset, training succeeds with CRNN and SVTR.