open-mmlab / mmocr

OpenMMLab Text Detection, Recognition and Understanding Toolbox
https://mmocr.readthedocs.io/en/dev-1.x/
Apache License 2.0

torch.jit.trace error: hope to trace dbnet_r18 model #62

Open alanguo1234 opened 3 years ago

alanguo1234 commented 3 years ago

Hi,
I am trying to trace the dbnet_r18 model with torch.jit.trace and get the error below:

mmocr# python demo/image_demo.py demo/demo_text_det.jpg configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py  ./py_DBNet_r18/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth  demo/demo_text_det_pred.jpg
Traceback (most recent call last):
  File "demo/image_demo.py", line 52, in <module>
    main()
  File "demo/image_demo.py", line 36, in main
    script_fun = torch.jit.trace(model, tmp_in)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 875, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1021, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 720, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1884, in __init__
    tmp_module._modules[name] = make_module(submodule, TracedModule, _compilation_unit=None)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 720, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1884, in __init__
    tmp_module._modules[name] = make_module(submodule, TracedModule, _compilation_unit=None)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 720, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 1884, in __init__
    tmp_module._modules[name] = make_module(submodule, TracedModule, _compilation_unit=None)
  File "/opt/conda/lib/python3.7/site-packages/torch/jit/__init__.py", line 703, in make_module
    elif torch._jit_internal.module_has_exports(mod):
  File "/opt/conda/lib/python3.7/site-packages/torch/_jit_internal.py", line 438, in module_has_exports
    item = getattr(mod, name)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 594, in __getattr__
    type(self).__name__, name))
AttributeError: 'ConvModule' object has no attribute 'norm'

I modified image_demo.py as follows:

from argparse import ArgumentParser

import mmcv

from mmdet.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.datasets import build_dataset  # noqa: F401
from mmocr.models import build_detector  # noqa: F401
import torch

def main():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file.')
    parser.add_argument('config', help='Config file.')
    parser.add_argument('checkpoint', help='Checkpoint file.')
    parser.add_argument('save_path', help='Path to save visualized image.')
    parser.add_argument(
        '--device', default='cpu', help='Device used for inference.')
    parser.add_argument(
        '--imshow',
        action='store_true',
        help='Whether show image with OpenCV.')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    device = torch.device('cpu')
    chkpt = torch.load("./py_DBNet_r18/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth", map_location=device)
    model.load_state_dict(chkpt['state_dict'], strict=False)  # redundant: init_detector already loaded this checkpoint

    tmp_in = torch.rand(64, 3, 7, 7)
    script_fun = torch.jit.trace(model, tmp_in)
    script_fun.save("./py_DBNet_r18/dbnet_r18_fpnc_trace.pt")

    # test a single image
    result = model_inference(model, args.img)
    print(f'result: {result}')

    # show the results
    img = model.show_result(args.img, result, out_file=None, show=False)

    mmcv.imwrite(img, args.save_path)
    if args.imshow:
        mmcv.imshow(img, 'predicted results')

if __name__ == '__main__':
    main()
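
The dummy input torch.rand(64,3,7,7) in the script above is very small for a detector. Assuming DBNet-R18's backbone follows the standard ResNet-18 layout (a 7x7 stride-2 stem convolution, a 3x3 stride-2 max-pool, then three further stride-2 stages), the arithmetic below shows that a 7x7 image is already down to 1x1 by the second stage, while a side length that is a multiple of 32 (640 here, picked only as an example) keeps the four feature maps at 1/4, 1/8, 1/16 and 1/32 of the input. A batch of 1 is usually enough for tracing.

```python
from typing import List

# Sketch: side length of each ResNet-18 stage output for a square input.
# Assumes the standard ResNet layout: 7x7/2 stem conv (padding 3),
# 3x3/2 max-pool (padding 1), layer1 at stride 1, then layer2..layer4
# each downsampling with a 3x3/2 conv (padding 1).


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_stage_sizes(side: int) -> List[int]:
    """Return the side lengths of the C2..C5 feature maps."""
    x = conv_out(side, kernel=7, stride=2, padding=3)  # stem conv
    x = conv_out(x, kernel=3, stride=2, padding=1)     # max-pool
    sizes = [x]                                        # layer1 keeps the size
    for _ in range(3):                                 # layer2..layer4
        x = conv_out(x, kernel=3, stride=2, padding=1)
        sizes.append(x)
    return sizes


if __name__ == "__main__":
    print(resnet18_stage_sizes(7))    # [2, 1, 1, 1]
    print(resnet18_stage_sizes(640))  # [160, 80, 40, 20]
```

If the traced graph is meant for real images, an example size close to the test-time size is the safer choice, since any shape-dependent Python logic is frozen at the values seen during that one tracing run.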

jeffreykuang commented 3 years ago

I ran the following command:

python demo/image_demo.py demo/demo_text_det.jpg configs/textdet/dbnet/dbnet_r18_fpnc_1200e_icdar2015.py checkout/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth demo/out.jpg

and got the correct output image:

[image: out]

jeffreykuang commented 3 years ago

ConvModule is a module from mmcv, and it does have a norm attribute. I am not sure whether JIT mistakes ConvModule for one of PyTorch's own modules.
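A hedged note on why the message can name norm even though ConvModule defines it: in mmcv, ConvModule.norm is a property that looks up the norm layer by its norm_name. In Python, if code inside a property raises AttributeError, the interpreter falls back to the class's __getattr__, and torch.nn.Module.__getattr__ then reports the property's own name as missing. The sketch below reproduces that with a hypothetical ConvModuleLike class that only mimics nn.Module's lookup; it does not show which lookup actually fails inside the real ConvModule during tracing, which this thread does not establish.

```python
# Hypothetical stand-in, NOT mmcv's ConvModule: it only mimics how
# torch.nn.Module.__getattr__ searches registered submodules and then
# raises "'<Class>' object has no attribute '<name>'".


class ConvModuleLike:
    def __init__(self):
        self._modules = {}       # submodules registered by name
        self.norm_name = "bn"    # name the norm property looks up

    def __getattr__(self, name):
        # Only called after normal attribute lookup raised AttributeError.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'")

    @property
    def norm(self):
        # Fails because nothing called "bn" is registered. That inner
        # AttributeError makes Python retry via __getattr__("norm"), so the
        # final message names "norm" instead of "bn".
        return getattr(self, self.norm_name)


if __name__ == "__main__":
    m = ConvModuleLike()
    try:
        m.norm
    except AttributeError as err:
        print(err)  # 'ConvModuleLike' object has no attribute 'norm'
```

In practice this means the reported name can hide the real failing lookup, so inspecting the offending ConvModule's norm_name and its registered submodules is one way to narrow it down.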

alanguo1234 commented 3 years ago

Hi, if I don't add the jit.trace code to image_demo.py, it runs fine.
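Separately from the norm error, tracing model directly is likely to fail: an mmdet 2.x detector's forward expects (img, img_metas, return_loss=...), not a single tensor. mmdet 2.x single-stage detectors also provide forward_dummy(img), which runs backbone, neck and head on a tensor alone; assuming the installed mmocr DBNet inherits it, a thin wrapper can expose it as forward for torch.jit.trace. The sketch below uses a hypothetical ToyDetector in place of DBNet so it runs without mmocr; DummyForward is likewise a name introduced only for illustration.

```python
import torch
import torch.nn as nn


class DummyForward(nn.Module):
    """Expose a detector's tensor-only forward_dummy() as forward() for tracing."""

    def __init__(self, detector: nn.Module):
        super().__init__()
        self.detector = detector

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        return self.detector.forward_dummy(img)


class ToyDetector(nn.Module):
    """Hypothetical stand-in with an mmdet-style forward signature (not DBNet)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.head = nn.Conv2d(8, 1, kernel_size=1)

    def forward(self, img, img_metas, return_loss=True):
        # Needs img_metas, so torch.jit.trace(model, tensor) cannot call it.
        raise NotImplementedError("not traceable with a single tensor")

    def forward_dummy(self, img):
        # Tensor in, tensor out: backbone then head, like a probability map.
        return torch.sigmoid(self.head(self.backbone(img)))


if __name__ == "__main__":
    detector = ToyDetector().eval()
    example = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        traced = torch.jit.trace(DummyForward(detector), example)
    print(traced(example).shape)  # torch.Size([1, 1, 32, 32])
```

With the real model this would look like torch.jit.trace(DummyForward(model), torch.rand(1, 3, 640, 640)). The wrapper does not by itself address the ConvModule norm error: the traceback shows tracing visiting every submodule, so the same error may still appear on this PyTorch version.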

alanguo1234 commented 3 years ago

Do you have WeChat or a WeChat group? I'd like to contact you. Thanks.

jeffreykuang commented 3 years ago

[two images]