robertknight / ocrs-models

PyTorch models for the ocrs OCR engine

ONNX export error when run on system with CUDA device available #29

Closed: Phaired closed this issue 2 months ago

Phaired commented 2 months ago

Hey, I generated a synthetic dataset for my needs and trained the models. I wanted to export them to ONNX and then to RTEN, but I'm having trouble with the ONNX conversion. Am I missing something?

root@6e1d2e750775:/workspace/ocr/ocrs-models# ls
Makefile   a.py     datasets  layout-scraper  ocrs_models  poetry.lock     synt                          text-detection-checkpoint.pt.bak
README.md  box.png  docs      mypy.ini        owndata      pyproject.toml  text-detection-checkpoint.pt  wandb
root@6e1d2e750775:/workspace/ocr/ocrs-models# poetry run python -m ocrs_models.train_detection hiertext datasets/hiertext/ --checkpoint text-detection-checkpoint.pt --export text-detection.onnx
True
datasets/hiertext/
Training dataset: images 8000 in 2000 batches
Validation dataset: images 1600 in 400 batches
Model param count: 622122
/workspace/ocr/ocrs-models/ocrs_models/train_detection.py:212: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(filename, map_location=device)
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workspace/ocr/ocrs-models/ocrs_models/train_detection.py", line 491, in <module>
    main()
  File "/workspace/ocr/ocrs-models/ocrs_models/train_detection.py", line 399, in main
    torch.onnx.export(
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 551, in export
    _export(
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 1648, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 1170, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 1046, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/onnx/utils.py", line 950, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/jit/_trace.py", line 1497, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/workspace/ocr/ocrs-models/ocrs_models/models.py", line 132, in forward
    x = self.in_conv(x)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/workspace/ocr/ocrs-models/ocrs_models/models.py", line 41, in forward
    return self.seq(x)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/workspace/ocr/ocrs-models/ocrs_models/models.py", line 28, in forward
    return self.seq(x)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
root@6e1d2e750775:/workspace/ocr/ocrs-models# poetry run pip show torch
Name: torch
Version: 2.4.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /root/.cache/pypoetry/virtualenvs/ocrs-models-E9ELUCGY-py3.10/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: torchvision

Running on an A100 with CUDA 12.2, driver 535.154.05.

robertknight commented 2 months ago

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

This error means that one of the tensors is on the CPU and the other is on the GPU. In this case the dummy model input used during export is on the CPU, but the model weights are on the GPU. When I've exported models in the past, I think I always did so on a machine with only a CPU. So this is a bug in the training script: it doesn't account for the possibility that you're doing the export on a system with a GPU (!)
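A quick way to catch this kind of mismatch before tracing starts is to compare the devices of the model's parameters with those of the example inputs. The sketch below is not part of ocrs-models; tensor_devices and assert_same_device are hypothetical helper names.

```python
import torch
from torch import nn


def tensor_devices(model: nn.Module, *inputs: torch.Tensor) -> tuple:
    """Collect the devices holding the model's parameters and the example inputs."""
    param_devices = {p.device for p in model.parameters()}
    input_devices = {t.device for t in inputs}
    return param_devices, input_devices


def assert_same_device(model: nn.Module, *inputs: torch.Tensor) -> None:
    """Fail early with a readable message instead of deep inside F.conv2d."""
    param_devices, input_devices = tensor_devices(model, *inputs)
    if len(param_devices | input_devices) > 1:
        raise RuntimeError(
            f"device mismatch: parameters on {sorted(map(str, param_devices))}, "
            f"inputs on {sorted(map(str, input_devices))}"
        )
```

Calling assert_same_device(model, test_image) just before torch.onnx.export would report "parameters on ['cuda:0'], inputs on ['cpu']" in the situation above, rather than the long traceback.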

In torch.onnx.export calls, the model and the dummy input need to be on the same device. The model is on the GPU here, so moving the dummy input to the GPU means changing train_detection.py from:

        test_batch = next(iter(val_dataloader))
        test_image = test_batch["image"][0:1]

        torch.onnx.export(

To:

        test_batch = next(iter(val_dataloader))
        test_image = test_batch["image"][0:1].to(device)

        torch.onnx.export(
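A variant of the same fix, sketched under the assumption that all of the model's parameters live on a single device: take the target device from the model itself rather than from the script's device variable, so the export works the same on CPU-only and GPU machines. model_device and export_onnx are hypothetical names, not functions from train_detection.py.

```python
import torch
from torch import nn


def model_device(model: nn.Module) -> torch.device:
    """Device of the model's first parameter.

    Assumes every parameter lives on one device, which holds after
    model.to(device) on a single-device setup.
    """
    return next(model.parameters()).device


def export_onnx(model: nn.Module, example: torch.Tensor, path: str) -> None:
    """Export with the example input moved next to the weights, CPU or GPU."""
    model.eval()
    example = example.to(model_device(model))
    torch.onnx.export(model, (example,), path)
```

In the training script this would correspond to something like export_onnx(model, test_batch["image"][0:1], "text-detection.onnx").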

The other option is to find the line that initializes device and make it use the CPU instead, so that the model weights and the dummy input both end up on the CPU.
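The device initialization line itself isn't shown in this thread, so the following is only a hypothetical sketch of how it could take an export flag. Because the traceback shows the checkpoint being loaded with torch.load(filename, map_location=device), choosing the CPU here would also place the restored weights on the CPU.

```python
import torch


def select_device(for_export: bool = False) -> torch.device:
    """Hypothetical device-selection helper.

    Training uses CUDA when available; export always uses the CPU, so the
    model weights match the CPU tensors produced by the dataloader.
    """
    if for_export or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda")
```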

Phaired commented 2 months ago

OK, that fixed the issue, but when I tried to use the custom detection model with the CLI (for testing), I got an error:

$ ocrs --detect-model text-detection.rten imga.png
Error: Failed to load text detection model from text-detection.rten

Caused by:
    parse error: Type `i32` at position 1313166418 is unaligned.

robertknight commented 2 months ago

Can you upload the ONNX model somewhere? Also, can you confirm which version of ocrs you have installed (ocrs --version) and which version of rten-convert you used (pip show rten-convert)?

Phaired commented 2 months ago

Here is the ONNX model: text-detection.zip.

$ ocrs --version
ocrs 0.8.0

As for rten-convert, I can't provide the exact version because I installed it recently on the pod used for training (which is now deleted), so it's likely the latest version.
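When the training machine is ephemeral, one way to keep this information is to record the converter version at export time. A minimal sketch using the standard library's importlib.metadata, which reports the same version string as pip show; installed_version is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a Python distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("rten-convert:", installed_version("rten-convert"))
```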

robertknight commented 2 months ago

Ah, I can see the problem. The .rten file format was changed recently to support larger models, but the published version of ocrs uses an older version of the rten library that does not recognize the latest format. The workaround is to pass the --v1 flag when running rten-convert to force use of the older file format:

 rten-convert --v1 text-detection.onnx
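The position in the parse error fits this explanation. Assuming V2 .rten files start with the 4-byte ASCII magic RTEN, and that a V1 reader treats the first 4 bytes of the file as a little-endian u32 FlatBuffers root-table offset, reading that magic as an offset gives exactly 1313166418, and that offset is not 4-byte aligned, hence "Type `i32` ... is unaligned". The sketch below, with hypothetical helper names, shows the arithmetic and a header check you could run before handing a file to an older ocrs.

```python
# Assumption: V2 .rten files begin with the 4-byte magic b"RTEN", while V1
# files are bare FlatBuffers data whose first 4 bytes are a little-endian
# u32 offset to the root table.
RTEN_MAGIC = b"RTEN"


def is_v2_header(data: bytes) -> bool:
    """True if the leading bytes carry the assumed V2 magic."""
    return data[:4] == RTEN_MAGIC


def looks_like_rten_v2(path: str) -> bool:
    """Check a model file's header without loading the whole file."""
    with open(path, "rb") as f:
        return is_v2_header(f.read(4))


# What a V1 reader sees when it treats the magic as a root-table offset:
offset = int.from_bytes(RTEN_MAGIC, "little")
print(offset)      # 1313166418, the position reported in the parse error
print(offset % 4)  # 2, so an i32 read at that position is misaligned
```

Under these assumptions, if looks_like_rten_v2("text-detection.rten") returns True, ocrs 0.8.0 won't load the file and it needs to be reconverted with --v1.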

When I publish the next release of ocrs it will support the latest (V2) .rten model format.