chophilip21 opened this issue 1 year ago
If you want to use the model in inference mode only, you don't need the train transform. The test transform is just a normalization, so you could probably replace this function with the ImageNet normalization.
The training augmentations are implemented without torch support, but you could try to recreate them with torchvision transforms.
What did you execute to get this error? Maybe I could look into it a bit if you provide the code.
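A torch-only normalization that TorchScript can compile might look like the sketch below. It assumes the test transform reduces to ImageNet mean/std normalization of a float image in [0, 1], as suggested above; `ImageNetNormalize` is a hypothetical name, and the sketch covers only the normalization step, not any resizing Nougat does.

```python
import torch
from torch import nn


class ImageNetNormalize(nn.Module):
    """Normalize a float image tensor in [0, 1] with ImageNet channel statistics."""

    def __init__(self) -> None:
        super().__init__()
        # Standard ImageNet per-channel mean/std, shaped (C, 1, 1) for broadcasting.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1))

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Works for (C, H, W) and (B, C, H, W) inputs via broadcasting.
        return (image - self.mean) / self.std


# Only torch ops are used, so torch.jit.script can compile the module.
normalize = torch.jit.script(ImageNetNormalize())
x = torch.rand(2, 3, 8, 8)
y = normalize(x)
```

Because the statistics are registered as buffers, they move with the module under `.to(device)` and are saved with it, so the normalization can live inside the exported graph instead of in a NumPy-based transform.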
Hi, thanks for responding!
I modified your predict.py to see if I can convert the model into TorchScript. As you mentioned, defining a wrapper like the one below, so that the preprocessing (prepare_input) is not called on the image, gets me past the first problem regarding train_transform.
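```python
import torch


class NougatWrapper(torch.nn.Module):
    """Wrap NougatModel so forward() takes an already-preprocessed image tensor."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, image):
        # Bypass prepare_input(): `image` is expected to be a preprocessed tensor batch.
        output = self.model.inference(image_tensors=image)
        return output
```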
So the TorchScript conversion logic is something like:

```python
from nougat import NougatModel

# `args` comes from the argument parser in predict.py.
model = NougatModel.from_pretrained(args.checkpoint).to(torch.bfloat16)
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.script(nougat_wrapper)
```
However, this fails because many of the modules in transformers/models/mbart are not convertible:
```
  File "/home/chophilip21/nougat/predict.py", line 188, in <module>
    main()
  File "/home/chophilip21/nougat/predict.py", line 139, in main
    script_model = torch.jit.script(nougat_wrapper)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 1284, in script
    return torch.jit._recursive.create_script_module(
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 480, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 542, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_script.py", line 614, in _construct
    init_fn(script_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 520, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 492, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 761, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 73, in make_stub_from_method
    return make_stub(func, method_name)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/_recursive.py", line 58, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 297, in get_jit_def
    return build_def(parsed_def.ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 335, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/torch/jit/frontend.py", line 359, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/chophilip21/nougat/env/lib/python3.10/site-packages/transformers/models/mbart/modeling_mbart.py", line 1703
    def forward(self, *args, **kwargs):
                      ~~~~~~~ <--- HERE
        return self.decoder(*args, **kwargs)
```
I did manage to get past the `**kwargs` restriction above by manually laying out all the parameters, but unfortunately there are a whole bunch of other errors related to converting mBART.
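The `*args`/`**kwargs` restriction can be reproduced on a toy module. In the sketch below, `Decoder`, `VarArgsWrapper` and `ExplicitWrapper` are hypothetical classes, not the actual mBART code: `VarArgsWrapper` mimics the `forward(self, *args, **kwargs)` pattern from the traceback, while `ExplicitWrapper` spells the parameter out with a type annotation, which is the workaround described above.

```python
import torch
from torch import nn


class Decoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)


class VarArgsWrapper(nn.Module):
    """Mimics the variadic forward() that TorchScript rejects."""

    def __init__(self) -> None:
        super().__init__()
        self.decoder = Decoder()

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


class ExplicitWrapper(nn.Module):
    """Same wrapper with an explicit, annotated signature."""

    def __init__(self) -> None:
        super().__init__()
        self.decoder = Decoder()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.decoder(hidden_states)


try:
    torch.jit.script(VarArgsWrapper())
    varargs_scripted = True
except torch.jit.frontend.NotSupportedError:
    # Same error class as in the traceback above.
    varargs_scripted = False

explicit = torch.jit.script(ExplicitWrapper())
out = explicit(torch.ones(2, 4))
```

This only removes the first obstacle: every submodule reached from `forward` must also be compilable, which is why further errors appear deeper inside mBART.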
Okay, so instead of `torch.jit.script` I tried `torch.jit.trace`, which does not run into the mBART scripting errors:

```python
a, b = next(iter(dataloader))  # a: batch of preprocessed page images
nougat_wrapper = NougatWrapper(model)
nougat_wrapper.eval()
script_model = torch.jit.trace(nougat_wrapper, a)
script_model.save("nougat.pt")
```
But I can see that the output of the network is a dictionary of lists:

```python
output = {
    "predictions": list(),
    "sequences": list(),
    "repeats": list(),
    "repetitions": list(),
}
```

This causes a RuntimeError, which kind of makes sense, since `predictions` holds the text output (strings) for the input PDF document:
```
Traceback (most recent call last):
  File "/home/chophilip21/latex_training/deploy.py", line 215, in <module>
    main()
  File "/home/chophilip21/latex_training/deploy.py", line 146, in main
    script_model = torch.jit.trace(nougat_wrapper, a)
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 794, in trace
    return trace_module(
  File "/home/chophilip21/miniconda3/envs/env/lib/python3.9/site-packages/torch/jit/_trace.py", line 1056, in trace_module
    module._c._create_method_from_trace(
RuntimeError: Tracer cannot infer type of {'predictions': ["\n\n# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients.\n\n"], 'sequences': tensor([[ 0, 25, 15337, 7834, 14229, 221, 221, 31226, 300, 38411,
92, 1296, 221, 221, 3323, 638, 2922, 221, 221, 5302,
9928, 6320, 286, 8988, 321, 42898, 5497, 1684, 3521, 301,
13099, 1287, 312, 301, 16199, 1369, 3247, 3343, 345, 8091,
2459, 36, 732, 4190, 44919, 286, 4622, 29, 105, 25473,
434, 286, 5764, 18043, 299, 25389, 896, 4271, 243, 40,
38, 40, 40, 36, 1077, 10449, 2865, 10498, 301, 286,
740, 299, 2679, 1383, 2330, 301, 3700, 312, 5462, 34,
312, 491, 25960, 286, 1673, 299, 3802, 3521, 301, 286,
740, 299, 1684, 18923, 363, 1932, 3538, 36, 2]]), 'repeats': [None], 'repetitions': ["# Beyond Linear Algebra\n\nBernd Sturmfels\n\n###### Abstract\n\nOur title challenges the reader to venture beyond linear algebra in designing models and in thinking about numerical algorithms for identifying solutions. This article accompanies the author's lecture at the International Congress of Mathematicians 2022. It covers recent advances in the study of critical point equations in optimization and statistics, and it explores the role of nonlinear algebra in the study of linear PDE with constant coefficients."]}
:Dictionary inputs to traced functions must have consistent type. Found List[str] and Tensor
```
It's probably because `postprocess` is being called on `output['predictions']`, so maybe I need to move that outside of the model definition. I need to dig deeper, but I'm not entirely sure my approach is correct. If you have any better ideas, please let me know!
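The trace error can be reproduced with a toy model, which also suggests one possible workaround. In the sketch below, `ToyGenerator` is a hypothetical stand-in that returns a dict mixing `List[str]` with a `Tensor`, like the output above, and `TensorOnlyWrapper` is a hypothetical variant of `NougatWrapper` that returns only the token-id tensor, leaving decoding and `postprocess` to plain Python after the call.

```python
import torch
from torch import nn


class ToyGenerator(nn.Module):
    """Hypothetical stand-in for NougatModel.inference: returns ids plus text."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(4, 3)

    def forward(self, image: torch.Tensor):
        sequences = self.head(image).argmax(dim=-1)
        # Mixed value types (List[str] next to Tensor), as in the traceback above.
        return {"predictions": ["decoded text"], "sequences": sequences}


class TensorOnlyWrapper(nn.Module):
    """Return only the token-id tensor so the tracer sees a single Tensor output."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        return self.model(image)["sequences"]


model = ToyGenerator().eval()
example = torch.rand(2, 4)

try:
    torch.jit.trace(model, example)
    mixed_output_traced = True
except RuntimeError:
    # The tracer cannot type a dict whose values mix List[str] and Tensor.
    mixed_output_traced = False

traced = torch.jit.trace(TensorOnlyWrapper(model), example)
token_ids = traced(example)
# Decoding token_ids back to text (tokenizer + postprocess) would run here,
# in plain Python, outside the traced graph.
```

One caveat, assuming `inference` runs an autoregressive generation loop (the `sequences` tensor suggests it does): a trace records only the operations executed for the example input, so the number of decoding steps taken for `a` would be frozen into the graph, and the tracer does not necessarily raise an error when that happens.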
Hi, first of all, thanks a lot for providing such an amazing model.
Probably this isn't at the top of your agenda, but your models are not compatible with TensorRT or TorchScript because some places in your scripts use the NumPy library. I'm referring to:
This code uses functions like `np.asarray`, which trigger these errors:
As far as I know, PyTorch has no plans to support NumPy arrays in TorchScript.
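One common way around this, sketched here under the assumption that the NumPy calls sit in image preprocessing, is to keep NumPy/PIL work in plain Python before the model is called and express anything that must live inside the compiled graph with torch ops only. `to_chw_float` is a hypothetical helper showing the kind of conversion `np.asarray` is often used for.

```python
import torch


@torch.jit.script
def to_chw_float(image_hwc: torch.Tensor) -> torch.Tensor:
    """Convert a uint8 (H, W, C) image to float32 (C, H, W) in [0, 1] using torch ops only."""
    return image_hwc.permute(2, 0, 1).to(torch.float32) / 255.0


# Outside the scripted code, a PIL image can still be converted with NumPy,
# e.g. torch.from_numpy(np.asarray(pil_image)), before calling the model.
image = torch.full((4, 5, 3), 255, dtype=torch.uint8)
chw = to_chw_float(image)
```

The design choice is to draw the export boundary at "tensor in, tensor out": file decoding and NumPy-based cropping stay in Python, and only the tensor math is compiled.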