mhamilton723 / FeatUp

Official code for "FeatUp: A Model-Agnostic Framework for Features at Any Resolution" ICLR 2024
MIT License

Exporting upsamplers to ONNX #4

Open opassos opened 3 months ago

opassos commented 3 months ago

I was trying to export the FeatUp upsamplers to ONNX, but the export fails with the following error:

SymbolicValueError: Unsupported: ONNX export of operator adaptive_avg_pool2d, output size that are not factor of input size. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues  [Caused by the value '299 defined in (%299 : int[] = prim::ListConstruct(%295, %298), scope: featup.upsamplers.JBUStack::
)' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.] 

    Inputs:
        #0: 295 defined in (%295 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={64}](), scope: featup.upsamplers.JBUStack:: # /content/FeatUp/featup/upsamplers.py:266:0
    )  (type 'Tensor')
        #1: 298 defined in (%298 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={64}](), scope: featup.upsamplers.JBUStack:: # /content/FeatUp/featup/upsamplers.py:266:0
    )  (type 'Tensor')
    Outputs:
        #0: 299 defined in (%299 : int[] = prim::ListConstruct(%295, %298), scope: featup.upsamplers.JBUStack::
    )  (type 'List[int]')

When running:

torch.onnx.export(
    upsampler.upsampler, 
    (lr_feats, image_tensor), 
    "upsampler.onnx", 
    opset_version=16, 
    input_names=["input"],
    output_names=["output"]
)

for the DINOv2 example in the Colab notebook.
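The error above concerns `adaptive_avg_pool2d`, not the custom op: the exporter reports that it cannot export adaptive average pooling when the output size is not a factor of the input size, and the failing call, traced from `JBUStack` (upsamplers.py:266), requests a 64 x 64 output. Assuming the input resolution is fixed at export time, one possible workaround (not something proposed by the FeatUp authors) is to express adaptive average pooling as two matrix multiplications with precomputed averaging matrices, which export as ordinary MatMul nodes. The helpers `_pool_matrix` and `adaptive_avg_pool2d_matmul` are hypothetical names; wiring them into `JBUStack` is left to the reader.

```python
import torch
import torch.nn.functional as F


def _pool_matrix(in_size: int, out_size: int, dtype=torch.float32) -> torch.Tensor:
    # Row i averages input indices [floor(i*in/out), ceil((i+1)*in/out)),
    # the same bin boundaries that adaptive_avg_pool2d uses.
    m = torch.zeros(out_size, in_size, dtype=dtype)
    for i in range(out_size):
        start = (i * in_size) // out_size
        end = -((-(i + 1) * in_size) // out_size)  # integer ceil division
        m[i, start:end] = 1.0 / (end - start)
    return m


def adaptive_avg_pool2d_matmul(x: torch.Tensor, output_size) -> torch.Tensor:
    # Hypothetical drop-in for F.adaptive_avg_pool2d(x, output_size) on
    # (..., H, W) inputs. Assumption: H and W are static, so the averaging
    # matrices become constants in the exported graph.
    out_h, out_w = output_size
    in_h, in_w = int(x.shape[-2]), int(x.shape[-1])
    ph = _pool_matrix(in_h, out_h, x.dtype).to(x.device)  # (out_h, in_h)
    pw = _pool_matrix(in_w, out_w, x.dtype).to(x.device)  # (out_w, in_w)
    # Averaging over a rectangle is separable: pool rows, then columns.
    return ph @ x @ pw.t()


x = torch.randn(2, 3, 224, 224)
print(torch.allclose(adaptive_avg_pool2d_matmul(x, (64, 64)),
                     F.adaptive_avg_pool2d(x, (64, 64)), atol=1e-5))
```

Because the matrices depend on the input height and width, this only covers a static export shape; a dynamic-shape export would need a different approach.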

opassos commented 3 months ago

I just realized you are using a custom CUDA kernel. I imagine this won't be easy to export to ONNX, right?

Is there a version of the model that uses vanilla PyTorch kernels?

mhamilton723 commented 3 months ago

Hey @opassos, yes, we use a custom op to make this efficient. There is a way to implement it in plain PyTorch, but it is inefficient.

The key operator that we turned into a custom op, the adaptive conv, can be written in PyTorch with the unfold operator, at a large memory cost:

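import torch

def adaptive_conv_py_simple(input, filters):
    # input:   (b, c, h1, w1) feature map to be filtered
    # filters: (b, h2, w2, f1, f2), one f1 x f2 kernel per output pixel,
    #          shared across all c channels
    b, c, h1, w1 = input.shape
    b, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    kernel_size = f1

    # Flatten each per-pixel kernel to a vector of length f1 * f2
    t_filters = filters.reshape(b, h2, w2, f1 * f2)

    # Unfold extracts every f1 x f2 patch (no padding, stride 1), so
    # h2 = h1 - f1 + 1 and w2 = w1 - f2 + 1. This tensor is f1 * f2 times
    # larger than the output, which is where the memory cost comes from.
    patches = torch.nn.Unfold(kernel_size)(input) \
        .view((b, c, f1 * f2, h2, w2))

    # Weighted sum over kernel positions f for every (b, c, h, w)
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)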
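
One way to sanity-check the snippet: when every output pixel uses the same kernel, the adaptive conv should collapse to an ordinary depthwise convolution (`groups=c`) with that kernel copied to every channel, since `F.conv2d` also computes an unpadded, stride-1 cross-correlation. The sketch below repeats the function (using `F.unfold`, the functional form of `torch.nn.Unfold`) and compares the two; the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F


def adaptive_conv_py_simple(input, filters):
    # Same computation as the snippet above, via the functional unfold
    b, c, h1, w1 = input.shape
    b, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    t_filters = filters.reshape(b, h2, w2, f1 * f2)
    patches = F.unfold(input, f1).view(b, c, f1 * f2, h2, w2)
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)


torch.manual_seed(0)
b, c, k = 2, 4, 3
x = torch.randn(b, c, 10, 12)
h2, w2 = 10 - k + 1, 12 - k + 1  # unpadded, stride-1 output size

# One shared k x k kernel, broadcast to every output position
kernel = torch.randn(k, k)
filters = kernel.expand(b, h2, w2, k, k)

out = adaptive_conv_py_simple(x, filters)
# Depthwise conv: weight (c, 1, k, k), the same kernel for every channel
ref = F.conv2d(x, kernel.repeat(c, 1, 1, 1), groups=c)

print(out.shape, torch.allclose(out, ref, atol=1e-5))
```

With spatially varying filters the two no longer agree, which is exactly the case the adaptive conv exists to handle.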

opassos commented 3 months ago

Thank you very much, will try this asap and report back the results.

opassos commented 3 months ago

Exporting the DINO models has its own issues (and workarounds), since some ops are not supported, so I started with the resnet50 upsampler instead. When running

upsampler = torch.hub.load("mhamilton723/FeatUp", 'resnet50').cuda()
hr_feats = upsampler(image_tensor)
lr_feats = upsampler.model(image_tensor)
plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0])

and then

torch.onnx.export(
    upsampler,
    (image_tensor, ),
    "upsampler.onnx",
    input_names=["image"],
    output_names=["hr_feats"],
    opset_version=16,
)

I get the following error:

iteration over a 0-d tensor

Traceback (most recent call last):
  File "/tmp/ipykernel_793854/2898748745.py", line 20, in <module>
    torch.onnx.export(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "~/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "~/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/.cache/torch/hub/mhamilton723_FeatUp_main/hubconf.py", line 19, in forward
    return self.upsampler(self.model(image), image)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 283, in forward
    source_2 = self.upsample(source, guidance, self.up1)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 279, in upsample
    upsampled = up(source, small_guidance)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 262, in forward
    return adaptive_conv_py_simple(hr_source_padded, combined_kernel)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 17, in adaptive_conv_py_simple
    patches = torch.nn.Unfold(kernel_size)(input) \
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/fold.py", line 298, in forward
    return F.unfold(input, self.kernel_size, self.dilation,
  File "~/.local/lib/python3.10/site-packages/torch/nn/functional.py", line 4697, in unfold
    return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/utils.py", line 11, in parse
    return tuple(x)
  File "~/.local/lib/python3.10/site-packages/torch/_tensor.py", line 930, in __iter__
    raise TypeError("iteration over a 0-d tensor")
TypeError: iteration over a 0-d tensor
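
The traceback shows what goes wrong: while the exporter traces the model, the sizes unpacked from `filters.shape` arrive as 0-d tensors, so `kernel_size` is a tensor, and the argument parsing in `F.unfold` (`_pair` → `tuple(x)`) tries to iterate over it. A minimal sketch of a fix, assuming the filter size is fixed for a given model so it can be baked into the graph as a constant, is to convert the kernel size to a Python `int` before calling unfold. `adaptive_conv_py_traceable` is a hypothetical name, and this has not been checked against a full FeatUp export, which may still hit other unsupported ops.

```python
import torch
import torch.nn.functional as F


def adaptive_conv_py_traceable(input, filters):
    # Hypothetical trace-friendly variant of adaptive_conv_py_simple.
    b, c, h1, w1 = input.shape
    _, h2, w2, f1, f2 = filters.shape
    # Under ONNX tracing these sizes may be 0-d tensors; int() turns the
    # kernel size into a constant so F.unfold receives a plain int.
    k = int(f1)
    assert k == int(f2)
    t_filters = filters.reshape(b, h2, w2, k * k)
    patches = F.unfold(input, kernel_size=k).view(b, c, k * k, h2, w2)
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)


# Eager check against a direct loop, with a different kernel at every pixel
torch.manual_seed(0)
x = torch.randn(1, 2, 6, 7)
k = 3
filters = torch.randn(1, 6 - k + 1, 7 - k + 1, k, k)
out = adaptive_conv_py_traceable(x, filters)

ref = torch.zeros_like(out)
for i in range(out.shape[2]):
    for j in range(out.shape[3]):
        window = x[:, :, i:i + k, j:j + k]  # (1, 2, k, k)
        # filters[:, i, j, None] is (1, 1, k, k) and broadcasts over channels
        ref[:, :, i, j] = (window * filters[:, i, j, None]).sum(dim=(-2, -1))

print(torch.allclose(out, ref, atol=1e-5))
```

Converting the size with `int()` makes it a constant in the traced graph (PyTorch typically emits a TracerWarning about this), so the exported model only supports the filter size seen during export.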