facebookresearch / nougat

Implementation of Nougat Neural Optical Understanding for Academic Documents
https://facebookresearch.github.io/nougat/
MIT License

Allow torch script export #54

Open chophilip21 opened 1 year ago

chophilip21 commented 1 year ago

Hi, first of all, thanks a lot for providing such an amazing model.

This probably isn't at the top of your agenda, but your models are not compatible with TensorRT or TorchScript because some places in the code use the NumPy library. I'm referring to:

from nougat.transforms import train_transform, test_transform

These transforms use functions such as np.asarray, which trigger the following error:

Traceback (most recent call last):
  File "/home/philip/latex_training/deploy.py", line 221, in <module>
    main()
  File "/home/philip/latex_training/deploy.py", line 149, in main
    module = torch.jit.script(nougat_wrapper)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 546, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 397, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 867, in try_compile_fn
    return torch.jit.script(fn, _rcb=rcb)
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1341, in script
    fn = torch._C._jit_script_compile(
RuntimeError: 
Python builtin <built-in function asarray> is currently not supported in Torchscript:
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/nougat/transforms.py", line 18
    def f(im):
        return transform(image=np.asarray(im))["image"]
                               ~~~~~~~~~~ <--- HERE
'f' is being compiled since it was called from 'SwinEncoder.__to_tensor_getter'
  File "/home/philip/latex_training/env/lib/python3.10/site-packages/nougat/model.py", line 144
    def to_tensor(self):
        if self.training:
            return train_transform
                   ~~~~~~~~~~~~~~~ <--- HERE
        else:
            return test_transform

As far as I know, PyTorch has no plans to make NumPy arrays compatible with TorchScript.
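The traceback shows that the failure comes from the `to_tensor` property on `SwinEncoder`, which TorchScript compiles as `__to_tensor_getter`. TorchScript skips class properties listed in a `__jit_unused_properties__` class attribute, so one untested option is to exclude that property before scripting. The sketch below uses a hypothetical `ToyEncoder` in place of `SwinEncoder`; for nougat it would presumably mean setting `SwinEncoder.__jit_unused_properties__ = ["to_tensor"]` before calling `torch.jit.script`. It only addresses this first error, not the mBART errors that come up later in the thread.

```python
import numpy as np
import torch
from torch import nn


def numpy_transform(im):
    # Hypothetical stand-in for nougat's transforms: it calls NumPy,
    # which TorchScript cannot compile.
    return torch.from_numpy(np.asarray(im, dtype=np.float32))


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for SwinEncoder with a NumPy-backed property."""

    # Properties named here are skipped when torch.jit.script compiles the class.
    __jit_unused_properties__ = ["to_tensor"]

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 2)

    @property
    def to_tensor(self):
        # Still usable from eager Python code for preprocessing.
        return numpy_transform

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


encoder = ToyEncoder().eval()
scripted = torch.jit.script(encoder)  # compiles because to_tensor is skipped
```

Preprocessing then stays in eager Python (`encoder.to_tensor(image)`), and only the tensor-in, tensor-out `forward` is part of the scripted module.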

lukas-blecher commented 1 year ago

If you want to use the model in inference mode only, you don't need the train transform. The test transform is just a normalization, so you could probably override this function with ImageNet normalization.
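A minimal sketch of that suggestion, assuming resizing and padding to the model's input size still happen outside the module, and that the module receives a `uint8` tensor of shape `(H, W, 3)` like the one `np.asarray(pil_image)` produces. It uses the standard ImageNet mean and standard deviation (the values timm exposes as `IMAGENET_DEFAULT_MEAN` / `IMAGENET_DEFAULT_STD`); whether this matches nougat's `test_transform` exactly should be checked against `nougat/transforms.py`. `TorchNormalize` is a hypothetical name.

```python
import torch
from torch import nn

# Standard ImageNet statistics (same values as timm's IMAGENET_DEFAULT_MEAN/STD).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class TorchNormalize(nn.Module):
    """Torch-only replacement for a normalize + to-tensor step (hypothetical)."""

    def __init__(self) -> None:
        super().__init__()
        # Buffers shaped (C, 1, 1) broadcast over height and width.
        self.register_buffer("mean", torch.tensor(IMAGENET_MEAN).view(3, 1, 1))
        self.register_buffer("std", torch.tensor(IMAGENET_STD).view(3, 1, 1))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # (H, W, 3) uint8 -> (3, H, W) float in [0, 1] -> normalized
        x = image.permute(2, 0, 1).float() / 255.0
        return (x - self.mean) / self.std


normalize = torch.jit.script(TorchNormalize())
```

Dividing by 255 before subtracting the mean is equivalent to albumentations-style normalization with `max_pixel_value=255`, and because the module uses only tensor ops it can be scripted or traced together with the model.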

The training augmentations are implemented without Torch support, but you could try to recreate them with torchvision transforms.

What did you execute to get this error? Maybe I could look into it a bit if you provide the code.

chophilip21 commented 1 year ago

Hi, thanks for responding!

I modified your predict.py to see whether I can convert the model to TorchScript. As you mentioned, defining a wrapper like the one below, which keeps the preprocessing code (prepare_input) from being called on the image, gets me past the first problem with train_transform.

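import torch


class NougatWrapper(torch.nn.Module):
    """Calls NougatModel.inference directly, skipping prepare_input and its NumPy-based transforms."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        # image: an already preprocessed batch of image tensors
        output = self.model.inference(image_tensors=image)
        return output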

So the TorchScript logic is roughly:

model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.script(nougat_wrapper)

However, this fails because many of the modules in transformers/models/mbart cannot be converted:

Traceback (most recent call last):
  File "/home/chophilip21/nougat/predict.py", line 188, in <module>
    main()
  File "/home/chophilip21/nougat/predict.py", line 139, in main
    script_model = torch.jit.script(nougat_wrapper)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 492, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 761, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 73, in make_stub_from_method
    return make_stub(func, method_name)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 58, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 297, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 335, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/transformers/models/mbart/modeling_mbart.py", line 1703
    def forward(self, *args, **kwargs):
                              ~~~~~~~ <--- HERE
        return self.decoder(*args, **kwargs)

I did manage to get past the **kwargs restriction above by spelling out all the parameters explicitly, but unfortunately there are a whole bunch of other errors related to converting mBART.
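The `**kwargs` workaround amounts to giving `forward` an explicit, typed signature, since TorchScript cannot compile `*args`/`**kwargs`. Below is a self-contained illustration with hypothetical toy modules (not the real mBART classes), assuming only PyTorch.

```python
from typing import Optional

import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder; not the real mBART decoder."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.proj(hidden)
        if mask is not None:
            out = out * mask
        return out


class VarargWrapper(nn.Module):
    """Same pattern as the failing wrapper: forward(*args, **kwargs)."""

    def __init__(self) -> None:
        super().__init__()
        self.decoder = ToyDecoder()

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


class ExplicitWrapper(nn.Module):
    """Every argument spelled out with types, so TorchScript can compile it."""

    def __init__(self) -> None:
        super().__init__()
        self.decoder = ToyDecoder()

    def forward(self, hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.decoder(hidden, mask)


def try_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script accepts the module."""
    try:
        torch.jit.script(module)
        return True
    except Exception:  # e.g. torch.jit.frontend.NotSupportedError for *args/**kwargs
        return False


print(try_script(VarargWrapper()), try_script(ExplicitWrapper()))
```

Every method TorchScript reaches has to be fixed this way, which is why the approach gets expensive for a large Hugging Face model.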

chophilip21 commented 1 year ago

Okay, tracing with torch.jit.trace instead of torch.jit.script gets past those errors:

a, b = next(iter(dataloader))
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.trace(nougat_wrapper, a)
script_model.save("nougat.pt")

But the output of the network is a dictionary of lists:

output = {
    "predictions": list(),
    "sequences": list(),
    "repeats": list(),
    "repetitions": list(),
}

This causes a RuntimeError, which makes sense, since `predictions` holds the decoded string output for the input PDF document rather than a tensor:

Traceback (most recent call last):
  File "/home/chophilip21/latex_training/deploy.py", line 215, in <module>
    main()
  File "/home/chophilip21/latex_training/deploy.py", line 146, in main
    script_model = torch.jit.trace(nougat_wrapper, a)
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of {'predictions': ["\n\n# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients.\n\n"], 'sequences': tensor([[    0,    25, 15337,  7834, 14229,   221,   221, 31226,   300, 38411,
            92,  1296,   221,   221,  3323,   638,  2922,   221,   221,  5302,
          9928,  6320,   286,  8988,   321, 42898,  5497,  1684,  3521,   301,
         13099,  1287,   312,   301, 16199,  1369,  3247,  3343,   345,  8091,
          2459,    36,   732,  4190, 44919,   286,  4622,    29,   105, 25473,
           434,   286,  5764, 18043,   299, 25389,   896,  4271,   243,    40,
            38,    40,    40,    36,  1077, 10449,  2865, 10498,   301,   286,
           740,   299,  2679,  1383,  2330,   301,  3700,   312,  5462,    34,
           312,   491, 25960,   286,  1673,   299,  3802,  3521,   301,   286,
           740,   299,  1684, 18923,   363,  1932,  3538,    36,     2]]), 'repeats': [None], 'repetitions': ["# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients."]}
:Dictionary inputs to traced functions must have consistent type. Found List[str] and Tensor

This is probably because postprocess is called to produce output['predictions'], so maybe that step needs to live outside the model definition. I need to dig deeper, but I'm not sure my approach is correct. If you have any better ideas, please let me know!
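One way to act on that idea: trace a wrapper that returns only tensors, then run decoding and postprocessing in ordinary Python afterwards. The sketch uses a hypothetical `ToyNougat` whose `inference` returns a dict that mixes strings and token-id tensors, like nougat's. For nougat this would presumably mean exposing `sequences` from the traced module and applying the tokenizer and `postprocess` outside it; that mapping is an assumption. Also note that `torch.jit.trace` records only the operations executed for the example input, so a data-dependent generation loop is frozen to the number of steps the example page took. It is worth comparing traced and eager outputs on a different page before relying on the trace.

```python
from typing import Dict, List, Union

import torch
from torch import nn


def decode(sequences: torch.Tensor) -> List[str]:
    # Hypothetical placeholder for tokenizer decoding + postprocess.
    # It works on Python values, so it should run outside the traced module.
    return [" ".join(str(t) for t in row) for row in sequences.tolist()]


class ToyNougat(nn.Module):
    """Hypothetical stand-in whose inference() mixes strings and tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(4, 4)

    def generate_ids(self, image_tensors: torch.Tensor) -> torch.Tensor:
        # Tensor-only part of inference: token ids.
        return self.head(image_tensors).argmax(dim=-1)

    def inference(self, image_tensors: torch.Tensor) -> Dict[str, Union[List[str], torch.Tensor]]:
        sequences = self.generate_ids(image_tensors)
        return {"predictions": decode(sequences), "sequences": sequences}


class DictWrapper(nn.Module):
    """Returns the whole dict; tracing fails because it mixes List[str] and Tensor."""

    def __init__(self, model: ToyNougat) -> None:
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor):
        return self.model.inference(image_tensors=image)


class TensorOnlyWrapper(nn.Module):
    """Returns only token ids, which torch.jit.trace can handle."""

    def __init__(self, model: ToyNougat) -> None:
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.model.generate_ids(image)


torch.manual_seed(0)
model = ToyNougat().eval()
example = torch.randn(2, 3, 4)

traced = torch.jit.trace(TensorOnlyWrapper(model), example)
texts = decode(traced(example))  # decoding happens after the traced call
```

Keeping the exported artifact tensor-in, tensor-out also matches what TensorRT and most other deployment runtimes expect. String handling then lives in the calling code.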