mhamilton723 / FeatUp

Official code for "FeatUp: A Model-Agnostic Framework for Features at Any Resolution" ICLR 2024
MIT License

Exporting upsamplers to ONNX #4

Open opassos opened 3 months ago

opassos commented 3 months ago

I was trying to export these models to ONNX, but the export fails with the following error:

SymbolicValueError: Unsupported: ONNX export of operator adaptive_avg_pool2d, output size that are not factor of input size. Please feel free to request support or submit a pull request on PyTorch GitHub:  [Caused by the value '299 defined in (%299 : int[] = prim::ListConstruct(%295, %298), scope: featup.upsamplers.JBUStack::
)' (type 'List[int]') in the TorchScript graph. The containing node has kind 'prim::ListConstruct'.] 

        #0: 295 defined in (%295 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={64}](), scope: featup.upsamplers.JBUStack:: # /content/FeatUp/featup/
    )  (type 'Tensor')
        #1: 298 defined in (%298 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={64}](), scope: featup.upsamplers.JBUStack:: # /content/FeatUp/featup/
    )  (type 'Tensor')
        #0: 299 defined in (%299 : int[] = prim::ListConstruct(%295, %298), scope: featup.upsamplers.JBUStack::
    )  (type 'List[int]')

When running:

    (lr_feats, image_tensor), 

for the DINOv2 example on Colab.
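For context on this first error: as far as we understand PyTorch's TorchScript-based ONNX exporter, it can only lower `adaptive_avg_pool2d` to a fixed-window average pool when each input side is an integer multiple of the requested output side (here 64x64); otherwise it raises the "not factor of input size" error above. The sketch below shows that equivalence in eager mode. The helper `adaptive_pool_as_avg_pool` is hypothetical (not part of FeatUp or PyTorch), and the 256 and 224 input sizes are illustrative assumptions, not the sizes used in the Colab.

```python
import torch
import torch.nn.functional as F


def adaptive_pool_as_avg_pool(x: torch.Tensor, out_hw: tuple) -> torch.Tensor:
    # Hypothetical helper: when the input is an integer multiple of the output,
    # adaptive average pooling is a plain average pool with
    # kernel_size == stride == in_size // out_size.
    h, w = x.shape[-2:]
    oh, ow = out_hw
    if h % oh != 0 or w % ow != 0:
        # Non-factor sizes need variable-width windows: no single fixed kernel exists.
        raise ValueError(f"input {h}x{w} is not a multiple of output {oh}x{ow}")
    kh, kw = h // oh, w // ow
    return F.avg_pool2d(x, kernel_size=(kh, kw), stride=(kh, kw))


x = torch.randn(1, 3, 256, 256)
ref = F.adaptive_avg_pool2d(x, (64, 64))
alt = adaptive_pool_as_avg_pool(x, (64, 64))
print(torch.allclose(ref, alt, atol=1e-5))  # True

try:
    adaptive_pool_as_avg_pool(torch.randn(1, 3, 224, 224), (64, 64))
except ValueError as err:
    print(err)  # input 224x224 is not a multiple of output 64x64
```

One possible workaround, not tested against FeatUp here, is to choose the input resolution so that every size the upsampler pools the guidance image to divides it evenly.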

opassos commented 3 months ago

I just realized you are using a custom CUDA kernel. I imagine this won't be easy to export to ONNX, right?

Is there a version of the model that uses vanilla PyTorch kernels?

mhamilton723 commented 3 months ago

Hey @opassos, yes, we use a custom op to make this efficient. There is a way to implement it in plain PyTorch, though inefficiently.

The key operator that we turned into a custom op, the adaptive convolution, can be written in PyTorch using the unfold operator, at a large memory cost:

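import torch


def adaptive_conv_py_simple(input, filters):
    # input:   (b, c, h1, w1) feature map, already padded
    # filters: (b, h2, w2, k, k), one k x k kernel per output pixel,
    #          with h2 = h1 - k + 1 and w2 = w1 - k + 1
    b, c, h1, w1 = input.shape
    b, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    kernel_size = f1

    # Flatten each per-pixel kernel to length k * k
    t_filters = filters.reshape(b, h2, w2, f1 * f2)

    # Unfold copies every k x k patch: (b, c*k*k, h2*w2) -> (b, c, k*k, h2, w2).
    # This materializes k*k copies of the input, hence the large memory cost.
    patches = torch.nn.Unfold(kernel_size)(input) \
        .view((b, c, f1 * f2, h2, w2))

    # Weight each patch by its own kernel (the same kernel for every channel)
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)
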
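To sanity-check what the unfold version computes, the sketch below compares it with a naive per-pixel loop. `adaptive_conv_loop` is a hypothetical reference written only for this check, and the shapes assume "valid" windows (h2 = h1 - k + 1, w2 = w1 - k + 1), matching the `.view` in the function above, which is copied here so the block runs on its own.

```python
import torch


def adaptive_conv_py_simple(input, filters):
    # Copy of the unfold-based version above
    b, c, h1, w1 = input.shape
    b, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    t_filters = filters.reshape(b, h2, w2, f1 * f2)
    patches = torch.nn.Unfold(f1)(input).view((b, c, f1 * f2, h2, w2))
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)


def adaptive_conv_loop(input, filters):
    # Hypothetical slow reference:
    # out[b, c, i, j] = sum_{u, v} filters[b, i, j, u, v] * input[b, c, i + u, j + v]
    _, h2, w2, k, _ = filters.shape
    out = input.new_zeros(input.shape[0], input.shape[1], h2, w2)
    for i in range(h2):
        for j in range(w2):
            patch = input[:, :, i:i + k, j:j + k]  # (b, c, k, k)
            out[:, :, i, j] = torch.einsum('bcuv,buv->bc', patch, filters[:, i, j])
    return out


x = torch.randn(2, 3, 6, 7)        # (b, c, h1, w1)
filt = torch.randn(2, 4, 5, 3, 3)  # (b, h2, w2, k, k): h2 = 6 - 3 + 1, w2 = 7 - 3 + 1
fast = adaptive_conv_py_simple(x, filt)
slow = adaptive_conv_loop(x, filt)
print(fast.shape, torch.allclose(fast, slow, atol=1e-5))  # torch.Size([2, 3, 4, 5]) True
```

The memory cost comes from `patches`, which holds c * k * k * h2 * w2 values per batch item: roughly k² times the input, so about 9x for a 3x3 kernel and about 49x for a 7x7 one.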
opassos commented 3 months ago

Thank you very much, I will try this ASAP and report back with the results.

opassos commented 3 months ago

Exporting the DINO models has its own issues (and workarounds), with some ops not being supported, so I started with resnet50 instead. When running

upsampler = torch.hub.load("mhamilton723/FeatUp", 'resnet50').cuda()
hr_feats = upsampler(image_tensor)
lr_feats = upsampler.model(image_tensor)
plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0])

and then

    (image_tensor, ),

I get the following error:

iteration over a 0-d tensor

Traceback (most recent call last):
  File "/tmp/ipykernel_793854/", line 20, in <module>
  File "~/.local/lib/python3.10/site-packages/torch/onnx/", line 506, in export
  File "~/.local/lib/python3.10/site-packages/torch/onnx/", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "~/.local/lib/python3.10/site-packages/torch/onnx/", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "~/.local/lib/python3.10/site-packages/torch/onnx/", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "~/.local/lib/python3.10/site-packages/torch/jit/", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/jit/", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "~/.local/lib/python3.10/site-packages/torch/jit/", line 118, in wrapper
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/.cache/torch/hub/mhamilton723_FeatUp_main/", line 19, in forward
    return self.upsampler(self.model(image), image)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/research-foundational/FeatUp/featup/", line 283, in forward
    source_2 = self.upsample(source, guidance, self.up1)
  File "~/research-foundational/FeatUp/featup/", line 279, in upsample
    upsampled = up(source, small_guidance)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/research-foundational/FeatUp/featup/", line 262, in forward
    return adaptive_conv_py_simple(hr_source_padded, combined_kernel)
  File "~/research-foundational/FeatUp/featup/", line 17, in adaptive_conv_py_simple
    patches = torch.nn.Unfold(kernel_size)(input) \
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 298, in forward
    return F.unfold(input, self.kernel_size, self.dilation,
  File "~/.local/lib/python3.10/site-packages/torch/nn/", line 4697, in unfold
    return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/", line 11, in parse
    return tuple(x)
  File "~/.local/lib/python3.10/site-packages/torch/", line 930, in __iter__
    raise TypeError("iteration over a 0-d tensor")
TypeError: iteration over a 0-d tensor
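Reading the traceback, `kernel_size` reaches `nn.Unfold` as a 0-d tensor: while the exporter traces the model (via `torch.jit._get_trace_graph`), values read from `filters.shape` are recorded as tensors rather than Python ints, and `_pair` then tries to iterate over one. One workaround we would try, sketched below and not run through `torch.onnx.export` here, is to pass the kernel size as a plain Python int. `adaptive_conv_traceable` and the `AdaptiveConv` wrapper are hypothetical names, and the check only confirms that `torch.jit.trace` accepts the function and reproduces the eager output.

```python
import torch
import torch.nn.functional as F


def adaptive_conv_traceable(input: torch.Tensor, filters: torch.Tensor,
                            kernel_size: int) -> torch.Tensor:
    # Same math as adaptive_conv_py_simple, but the kernel size is a Python int,
    # so F.unfold never receives a traced 0-d tensor for it.
    b, c = input.shape[:2]
    h2, w2 = filters.shape[1:3]
    k = kernel_size
    t_filters = filters.reshape(b, h2, w2, k * k)
    # F.unfold -> (b, c*k*k, h2*w2); split channels and the patch grid back out
    patches = F.unfold(input, kernel_size=k).view(b, c, k * k, h2, w2)
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)


class AdaptiveConv(torch.nn.Module):
    # Hypothetical wrapper: stores the kernel size as a fixed attribute for tracing
    def __init__(self, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, input, filters):
        return adaptive_conv_traceable(input, filters, self.kernel_size)


x = torch.randn(1, 4, 10, 10)
filt = torch.randn(1, 8, 8, 3, 3)  # 10 - 3 + 1 = 8 output positions per side
module = AdaptiveConv(3).eval()
traced = torch.jit.trace(module, (x, filt))
print(torch.allclose(traced(x, filt), module(x, filt)))  # True
```

Fixing the kernel size at construction time seems reasonable here because it does not change between inputs; only the spatial sizes do.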