mosaicml / composer

Supercharge Your Model Training
http://docs.mosaicml.com
Apache License 2.0

ONNX export with `dynamic_axes` does not work when applying `BlurPool` #3466

Open dneup opened 1 month ago

dneup commented 1 month ago

When I try to save a ResNet18 model to ONNX with the `export_for_inference` function while passing both the `apply_blurpool` surgery algorithm and `dynamic_axes`, I get the following error:

/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:27: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h % 2 == 0:
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:29: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if w % 2 == 0:
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:31: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  return int(torch.div(h, 2)), int(torch.div(w, 2))
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:76: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (filter.shape[0] == 1) and (channels > 1):
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:81: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if h + 2 * padding[0] < filter_h:
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py:83: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if w + 2 * padding[1] < filter_w:
Traceback (most recent call last):
  File "/home/project/test_model_export.py", line 9, in <module>
    export_for_inference(
  File "/home/project/.venv/lib/python3.10/site-packages/composer/utils/inference.py", line 258, in export_for_inference
    torch.onnx.export(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1612, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1138, in _model_to_graph
    graph = _optimize_graph(
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py", line 1956, in _run_symbolic_function
    return symbolic_fn(graph_context, *inputs, **attrs)
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/symbolic_helper.py", line 306, in wrapper
    return fn(g, *args, **kwargs)
  File "/home/project/.venv/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py", line 2519, in _convolution
    raise errors.SymbolicValueError(
torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of convolution for kernel of unknown shape.  [Caused by the value 'maxs defined in (%maxs : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu), %146 : Long(*, 64, *, *, device=cpu) = onnx::MaxPool[ceil_mode=0, dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%142), scope: torchvision.models.resnet.ResNet::/composer.algorithms.blurpool.blurpool_layers.BlurMaxPool2d::maxpool # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:796:0
)' (type 'Tensor') in the TorchScript graph. The containing node has kind 'onnx::MaxPool'.] 
    (node defined in /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py(796): _max_pool2d
/home/project/.venv/lib/python3.10/site-packages/torch/_jit_internal.py(497): fn
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py(150): blurmax_pool2d
/home/project/.venv/lib/python3.10/site-packages/composer/algorithms/blurpool/blurpool_layers.py(201): forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/project/.venv/lib/python3.10/site-packages/torchvision/models/resnet.py(271): _forward_impl
/home/project/.venv/lib/python3.10/site-packages/torchvision/models/resnet.py(285): forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1522): _slow_forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/jit/_trace.py(129): wrapper
/home/project/.venv/lib/python3.10/site-packages/torch/jit/_trace.py(138): forward
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1541): _call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py(1532): _wrapped_call_impl
/home/project/.venv/lib/python3.10/site-packages/torch/jit/_trace.py(1310): _get_trace_graph
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(914): _trace_and_get_graph_from_model
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(1010): _create_jit_graph
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(1134): _model_to_graph
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(1612): _export
/home/project/.venv/lib/python3.10/site-packages/torch/onnx/utils.py(516): export
/home/project/.venv/lib/python3.10/site-packages/composer/utils/inference.py(258): export_for_inference
/home/project/test_model_export.py(9): <module>
)

    Inputs:
        #0: 142 defined in (%142 : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu) = onnx::Relu(%input.4), scope: torchvision.models.resnet.ResNet::/torch.nn.modules.activation.ReLU::relu # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:1498:0
    )  (type 'Tensor')
    Outputs:
        #0: maxs defined in (%maxs : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu), %146 : Long(*, 64, *, *, device=cpu) = onnx::MaxPool[ceil_mode=0, dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%142), scope: torchvision.models.resnet.ResNet::/composer.algorithms.blurpool.blurpool_layers.BlurMaxPool2d::maxpool # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:796:0
    )  (type 'Tensor')
        #1: 146 defined in (%maxs : Float(*, 64, *, *, strides=[65536, 1024, 32, 1], requires_grad=1, device=cpu), %146 : Long(*, 64, *, *, device=cpu) = onnx::MaxPool[ceil_mode=0, dilations=[1, 1], kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%142), scope: torchvision.models.resnet.ResNet::/composer.algorithms.blurpool.blurpool_layers.BlurMaxPool2d::maxpool # /home/project/.venv/lib/python3.10/site-packages/torch/nn/functional.py:796:0
    )  (type 'Tensor')

Environment

torch=2.3.1
torchvision=0.18.1
composer=0.23.5 

To reproduce

Code snippet that throws the error:

```python
import torch
from composer.utils import export_for_inference
import torchvision
import composer.functional as cf

model = torchvision.models.resnet18()

# ONNX export with BlurPool surgery plus dynamic batch/height/width axes;
# this call raises torch.onnx.errors.SymbolicValueError
export_for_inference(
    model=model,
    save_format="onnx",
    save_path="./model.onnx",
    sample_input=torch.rand(1, 3, 64, 64),
    dynamic_axes={"input": {0: "batch_size", 2: "height", 3: "width"}},
    surgery_algs=[cf.apply_blurpool],
)
```

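The error seems to happen in the code that applies BlurPool to the `MaxPool2d` layer (`BlurMaxPool2d`). The `blur_2d` function seems to be called with `num_channels=-1`, which triggers the shape-dependent control flow that tracing does not support (see the TracerWarnings above); presumably this is also why the exporter ends up with a convolution kernel of unknown shape.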
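To illustrate the mechanism described above, here is a minimal sketch of a blur-max-pool layer whose channel count is fixed when the module is constructed, so the depthwise blur kernel has a constant shape `(channels, 1, 3, 3)` in the traced graph instead of being derived from the input shape during tracing. `StaticBlurMaxPool2d` is a hypothetical module, not composer's `BlurMaxPool2d`; the 3x3 binomial filter, the padding, and the assumption that the channel count is known ahead of time are choices made for this example. It is only checked with `torch.jit.trace` here; whether an equivalent change makes `export_for_inference` succeed with `dynamic_axes` is untested.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StaticBlurMaxPool2d(nn.Module):
    """Hypothetical max-pool + blur layer with the channel count fixed up front.

    The depthwise blur kernel is built once in __init__, so nothing in
    forward() reads input.shape and the kernel shape is a trace-time constant.
    """

    def __init__(self, channels: int, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> None:
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Assumed 3x3 binomial filter: outer([1, 2, 1], [1, 2, 1]) / 16
        taps = torch.tensor([1.0, 2.0, 1.0])
        filt = torch.outer(taps, taps)
        filt = filt / filt.sum()
        # One copy per channel for a depthwise conv: shape (channels, 1, 3, 3)
        self.register_buffer("filt", filt.expand(channels, 1, 3, 3).clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dense (stride-1) max pooling, then blur and subsample in one depthwise conv
        maxs = F.max_pool2d(x, kernel_size=self.kernel_size, stride=1, padding=self.padding)
        return F.conv2d(maxs, self.filt, stride=self.stride, padding=1, groups=self.channels)


if __name__ == "__main__":
    pool = StaticBlurMaxPool2d(channels=64).eval()
    # Trace at one size, then run the trace at a different batch size and resolution
    traced = torch.jit.trace(pool, torch.rand(1, 64, 32, 32))
    print(traced(torch.rand(2, 64, 48, 40)).shape)  # torch.Size([2, 64, 24, 20])
```

Because `forward` never reads the input shape, the trace taken at 32x32 also runs at other batch sizes and resolutions. The trade-off is that the layer must be built with the right `channels` for the layer it replaces (64 for the ResNet18 stem max-pool in the traceback above).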


Also, the TracerWarnings should probably be raised as errors, since the traced model may not work as expected for inputs of other shapes.
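On the TracerWarning point, one user-side way to treat these warnings as errors is to record them during tracing and fail if any appear. The sketch below does this around `torch.jit.trace`; `trace_strict` is a hypothetical helper, not part of composer or PyTorch, and wrapping `torch.onnx.export` the same way (it traces the model internally in the torch version listed above) is an untested assumption.

```python
import warnings

import torch


def trace_strict(fn, example_input):
    """Hypothetical helper: torch.jit.trace that rejects traces with TracerWarnings.

    Warnings are recorded during tracing and, if any TracerWarning was seen,
    turned into a RuntimeError instead of returning a possibly wrong trace.
    """
    with warnings.catch_warnings(record=True) as caught:
        # "always" overrides torch's default filters so every TracerWarning is recorded
        warnings.simplefilter("always", category=torch.jit.TracerWarning)
        traced = torch.jit.trace(fn, example_input)
    problems = [w for w in caught if issubclass(w.category, torch.jit.TracerWarning)]
    if problems:
        messages = "\n".join(str(w.message) for w in problems)
        raise RuntimeError(f"Tracing produced {len(problems)} TracerWarning(s):\n{messages}")
    return traced


def branchy(x: torch.Tensor) -> torch.Tensor:
    # Tensor-dependent branch: the tracer can only record the branch taken
    if x.sum() > 0:
        return x * 2
    return x - 1
```

Applied to `branchy`, `trace_strict` raises instead of returning a trace that silently hard-codes the `x * 2` branch, while shape-agnostic functions trace as usual.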

mvpatel2000 commented 1 month ago

@dskhudia any suggestions here? Is it just that this isn't possible with dynamic axes?