autonomousvision / unimatch

[TPAMI'23] Unifying Flow, Stereo and Depth Estimation
https://haofeixu.github.io/unimatch/
MIT License

Add argument types to be able to use torch JIT #54

Open AdrianEddy opened 3 months ago

AdrianEddy commented 3 months ago

This PR adds type annotations to function signatures so that the model can be used with torch.jit or torch.onnx.export(). I also had to convert some functions into modules.
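TorchScript assumes that an unannotated parameter is a `Tensor`, so helpers that take ints, floats, or optional values need explicit annotations before `torch.jit.script` can compile them. The sketch below illustrates that idea and a function wrapped in a module; `split_into_windows` and `WindowSplitter` are hypothetical names for illustration, not code from this PR.

```python
import torch
from torch import nn


def split_into_windows(x: torch.Tensor, num_splits: int) -> torch.Tensor:
    # Hypothetical helper. Without the `int` annotation, TorchScript would
    # assume `num_splits` is a Tensor and reject an integer argument.
    b, w, c = x.shape
    window = w // num_splits
    # (B, W, C) -> (B * num_splits, window, C): consecutive positions along W
    # are grouped into windows.
    return x.reshape([b * num_splits, window, c])


class WindowSplitter(nn.Module):
    # Hypothetical module wrapping the helper, similar in spirit to the
    # function-to-module conversions the PR mentions.
    def __init__(self, num_splits: int) -> None:
        super().__init__()
        self.num_splits = num_splits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return split_into_windows(x, self.num_splits)


scripted = torch.jit.script(WindowSplitter(num_splits=2))
x = torch.randn(1, 8, 4)
print(scripted(x).shape)  # torch.Size([2, 4, 4])
```

Note that `torch.jit.script` compiles from source, so it expects this code to live in a regular `.py` file.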

The code should be equivalent to the previous version; I verified this with inference (I didn't test training, though).

It's easiest to review this with whitespace changes hidden in the diff.

Related to #29

ylab604 commented 2 months ago

@AdrianEddy Thank you for the great work! I tried running inference with your code, but got this error:

Traceback (most recent call last):
  File "main_stereo.py", line 612, in <module>
    main(args)
  File "main_stereo.py", line 331, in main
    inference_stereo(model_without_ddp,
  File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "unimatch/evaluate_stereo.py", line 799, in inference_stereo
    pred_disp = model(left, right,
  File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "unimatch/unimatch/unimatch.py", line 190, in forward
    feature0, feature1 = self.transformer(feature0, feature1,
  File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "unimatch/unimatch/transformer.py", line 272, in forward
    shifted_window_attn_mask_1d = self.generate_shift_window_attn_mask_1d(
  File "anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "unimatch/unimatch/utils.py", line 207, in forward
    mask_windows = window_partition_1d(img_mask, window_size_w)  # nW, window_size, 1
  File "unimatch/unimatch/utils.py", line 193, in window_partition_1d
    B, W, C = x.shape
ValueError: too many values to unpack (expected 3)
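The last frame of the traceback unpacks `x.shape` into exactly three names (`B, W, C`), so `window_partition_1d` fails as soon as it receives a tensor with more than three dimensions. The sketch below is a hypothetical 1D window partition with an explicit shape check that gives a clearer error; `window_partition_1d_sketch` is an illustrative name, and this is not the repository's implementation.

```python
import torch


def window_partition_1d_sketch(x: torch.Tensor, window_size: int) -> torch.Tensor:
    # Hypothetical re-implementation for illustration only.
    # Expects x of shape (B, W, C) and returns (B * nW, window_size, C),
    # where nW = W // window_size.
    if x.dim() != 3:
        raise ValueError(f"expected a 3D tensor (B, W, C), got shape {tuple(x.shape)}")
    b, w, c = x.shape
    if w % window_size != 0:
        raise ValueError("W must be divisible by window_size")
    return x.view(b, w // window_size, window_size, c).reshape(-1, window_size, c)


mask = torch.zeros(1, 16, 1)  # (B, W, 1) mask, matching the "nW, window_size, 1" comment
print(window_partition_1d_sketch(mask, 4).shape)  # torch.Size([4, 4, 1])

bad_mask = torch.zeros(1, 8, 16, 1)  # a 4D tensor is rejected with a descriptive error
try:
    window_partition_1d_sketch(bad_mask, 4)
except ValueError as e:
    print(e)  # expected a 3D tensor (B, W, C), got shape (1, 8, 16, 1)
```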

AdrianEddy commented 2 months ago

@ylab604 Please check now

ylab604 commented 2 months ago

I see, I had already changed the mask function when I checked your code yesterday. But the important thing is that the ONNX graph (viewed in Netron) does not look normal compared with pinto0309's.

ylab604 commented 2 months ago

Anyway, thank you for your kindness. I will also post the execution results.

AdrianEddy commented 2 months ago

What do you mean it's not normal? What's weird about it?

ylab604 commented 2 months ago


I mean that after converting to ONNX or JIT, the inference output differs from that of the original PyTorch model.
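One way to pin down such a mismatch is to run the eager model and its exported counterpart on the same input and compare the outputs numerically. A minimal sketch using `torch.jit.trace` on a small stand-in model; `TinyNet` and `max_abs_diff` are hypothetical names, not part of UniMatch, and the same comparison idea applies to outputs obtained from an ONNX runtime.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    # Hypothetical stand-in model; replace with the real network under test.
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x)).mean(dim=1)


def max_abs_diff(model: nn.Module, example: torch.Tensor) -> float:
    # Trace the model, then compare eager and traced outputs on the same input.
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        eager_out = model(example)
        traced_out = traced(example)
    return (eager_out - traced_out).abs().max().item()


torch.manual_seed(0)
diff = max_abs_diff(TinyNet(), torch.randn(1, 3, 32, 32))
print(f"max abs difference: {diff:.3e}")
```

If the exported model agrees on the example input but diverges on other input sizes, tracing itself is a plausible suspect: `torch.jit.trace` records only the operations executed for the example input, so Python-side values computed from input shapes (such as window sizes, padding, or attention masks) can end up frozen as constants. Running the traced model on a second input size and repeating the comparison is a quick way to check for this.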