opassos opened 3 months ago
I just realized you are using a custom CUDA kernel. I imagine this won't be easy to export to ONNX, right?
Is there a version of the model that uses vanilla PyTorch kernels?
Hey @opassos, yes, indeed we use a custom op to make this efficient. There is a way to implement this in plain PyTorch, though inefficiently: the key operator that we turned into a custom op, the adaptive conv, can be written in PyTorch using the unfold operator, at a large memory cost.
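```python
import torch

torch.manual_seed(42)


def adaptive_conv_py_simple(input, filters):
    # input:   (b, c, h1, w1) feature map, already padded
    # filters: (b, h2, w2, f1, f2), one f1 x f2 kernel per output pixel,
    #          shared across all c channels
    # Assumes h2 = h1 - f1 + 1 and w2 = w1 - f2 + 1.
    b, c, h1, w1 = input.shape
    b, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    kernel_size = f1
    # Flatten each per-pixel kernel: (b, h2, w2, f1 * f2)
    t_filters = filters.reshape(b, h2, w2, f1 * f2)
    # Unfold copies out every f1 x f2 window: (b, c * f1 * f2, h2 * w2),
    # viewed as (b, c, f1 * f2, h2, w2). This tensor is the memory cost.
    patches = torch.nn.Unfold(kernel_size)(input) \
        .view((b, c, f1 * f2, h2, w2))
    # Weighted sum of each window with that output pixel's kernel
    return torch.einsum('bhwf,bcfhw->bchw', t_filters, patches)
```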
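To put a rough number on that memory cost: `nn.Unfold` materializes every f1 × f2 window of every channel, so the patches tensor holds b·c·f1·f2·h2·w2 values, which is f1·f2 times as many as the (b, c, h2, w2) output. The sizes below (one image, 384 channels, a 224 × 224 output, a 7 × 7 kernel, float32) are hypothetical, picked for illustration rather than taken from FeatUp's actual configuration, and the helper names are made up for this sketch.

```python
def unfold_patch_bytes(b, c, h2, w2, kernel_size, bytes_per_elem=4):
    """Bytes held by the unfolded patches, shape (b, c * k * k, h2 * w2)."""
    return b * c * kernel_size * kernel_size * h2 * w2 * bytes_per_elem


def output_bytes(b, c, h2, w2, bytes_per_elem=4):
    """Bytes held by the adaptive-conv output, shape (b, c, h2, w2)."""
    return b * c * h2 * w2 * bytes_per_elem


MIB = 2 ** 20
# Hypothetical sizes, for illustration only (float32 = 4 bytes per element).
patches = unfold_patch_bytes(1, 384, 224, 224, kernel_size=7)
output = output_bytes(1, 384, 224, 224)
print(f"output:  {output / MIB:.1f} MiB")   # output:  73.5 MiB
print(f"patches: {patches / MIB:.1f} MiB")  # patches: 3601.5 MiB
```

The ratio between the two is always f1·f2 (49 here), regardless of batch size, channel count or resolution.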
Thank you very much, I will try this ASAP and report back with the results.
Exporting the DINO models has its own issues (and workarounds), since some ops are not supported, so I started by trying the resnet50 model. When running
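```python
import torch

# image_tensor, unnorm and plot_feats are assumed to be defined as in the
# FeatUp example notebook (not shown here).
upsampler = torch.hub.load("mhamilton723/FeatUp", 'resnet50').cuda()
hr_feats = upsampler(image_tensor)        # upsampled (high-resolution) features
lr_feats = upsampler.model(image_tensor)  # backbone (low-resolution) features
plot_feats(unnorm(image_tensor)[0], lr_feats[0], hr_feats[0])
```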
and then
```python
torch.onnx.export(
    upsampler,
    (image_tensor, ),
    "upsampler.onnx",
    input_names=["image"],
    output_names=["hr_feats"],
    opset_version=16,
)
```
I get the following error:
```
iteration over a 0-d tensor
Traceback (most recent call last):
  File "/tmp/ipykernel_793854/2898748745.py", line 20, in <module>
    torch.onnx.export(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "~/.local/lib/python3.10/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "~/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "~/.local/lib/python3.10/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/.cache/torch/hub/mhamilton723_FeatUp_main/hubconf.py", line 19, in forward
    return self.upsampler(self.model(image), image)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 283, in forward
    source_2 = self.upsample(source, guidance, self.up1)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 279, in upsample
    upsampled = up(source, small_guidance)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 262, in forward
    return adaptive_conv_py_simple(hr_source_padded, combined_kernel)
  File "~/research-foundational/FeatUp/featup/upsamplers.py", line 17, in adaptive_conv_py_simple
    patches = torch.nn.Unfold(kernel_size)(input) \
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/fold.py", line 298, in forward
    return F.unfold(input, self.kernel_size, self.dilation,
  File "~/.local/lib/python3.10/site-packages/torch/nn/functional.py", line 4697, in unfold
    return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
  File "~/.local/lib/python3.10/site-packages/torch/nn/modules/utils.py", line 11, in parse
    return tuple(x)
  File "~/.local/lib/python3.10/site-packages/torch/_tensor.py", line 930, in __iter__
    raise TypeError("iteration over a 0-d tensor")
TypeError: iteration over a 0-d tensor
```
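The last frames hint at the cause: `torch.onnx.export` traces the model, and under tracing the entries of `filters.shape` can come back as 0-d tensors instead of Python ints. `kernel_size = f1` is then a tensor, and `_pair(kernel_size)` inside unfold fails on `tuple(x)`. One possible workaround, not tested against FeatUp itself, is to cast the kernel size to a plain `int` before calling unfold, which bakes it into the trace as a constant (reasonable here, since the kernel size is fixed per layer). The sketch below uses two hypothetical helpers: `adaptive_conv_traceable`, a variant of `adaptive_conv_py_simple` with that cast, using `F.unfold` (the functional form of `nn.Unfold`), and `adaptive_conv_reference`, an explicit-loop version used only to check that the cast does not change the result.

```python
import torch
import torch.nn.functional as F


def adaptive_conv_traceable(input, filters):
    """Hypothetical trace-friendly variant of adaptive_conv_py_simple.

    input:   (b, c, h1, w1), already padded
    filters: (b, h2, w2, k, k) with h2 = h1 - k + 1, w2 = w1 - k + 1
    """
    b, c, _, _ = input.shape
    _, h2, w2, f1, f2 = filters.shape
    # Cast to Python ints so unfold receives a constant kernel size rather
    # than a 0-d tensor while the model is being traced.
    kernel_size = int(f1)
    assert kernel_size == int(f2)
    t_filters = filters.reshape(b, h2, w2, f1 * f2)
    # (b, c * k * k, h2 * w2) -> (b, c, k * k, h2, w2)
    patches = F.unfold(input, kernel_size).view(b, c, f1 * f2, h2, w2)
    return torch.einsum("bhwf,bcfhw->bchw", t_filters, patches)


def adaptive_conv_reference(input, filters):
    """Explicit loops: out[:, :, i, j] = sum of window(i, j) * filters[:, i, j]."""
    b, c, _, _ = input.shape
    _, h2, w2, k, _ = filters.shape
    out = input.new_zeros((b, c, h2, w2))
    for i in range(h2):
        for j in range(w2):
            window = input[:, :, i:i + k, j:j + k]   # (b, c, k, k)
            kernel = filters[:, i, j].unsqueeze(1)   # (b, 1, k, k), shared over c
            out[:, :, i, j] = (window * kernel).sum(dim=(-2, -1))
    return out


torch.manual_seed(0)
x = torch.randn(2, 3, 10, 12)
k = 3
w = torch.randn(2, 10 - k + 1, 12 - k + 1, k, k)
out = adaptive_conv_traceable(x, w)
assert out.shape == (2, 3, 8, 10)
assert torch.allclose(out, adaptive_conv_reference(x, w), atol=1e-5)
```

During tracing, PyTorch may print a TracerWarning about converting a tensor to a Python integer; that is expected for a constant kernel size. Whether the rest of the upsampler then exports cleanly is not verified here; rerunning the `torch.onnx.export` call above is the real test.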
I was trying to export those models to ONNX but am failing: I get an error when running the DinoV2 example on Colab.