AxisCommunications / onnx-to-keras

Convert onnx models exported from pytorch to tensorflow keras models with focus on performance and high-level compatibility.
MIT License

Implement concat for OnnxTensor #17

Closed: xsacha closed this 3 years ago

xsacha commented 3 years ago

Fixes #19

hakanardo commented 3 years ago

Thanks! Please consider adding a test in test_onnx2keras.py as well.

xsacha commented 3 years ago

I tried to create a test for it, but wasn't sure how to get an input of type OnnxTensor through to the concat.

hakanardo commented 3 years ago

Concatenating OnnxTensors will not always result in an InterleavedImageBatch. Do you have some (preferably minimal) example model you could share that triggers the issue? (That's what we'll need for the test anyway.)

xsacha commented 3 years ago

I'm not entirely sure how that OnnxTensor input gets produced. Any model that produces one would make a suitable test.

The line in PyTorch looks like:

classifications = torch.cat([self.ClassHead[i](feature).permute(0,2,3,1).reshape(batch, -1, 2) for i, feature in enumerate(features)],dim=1)

Maybe it's due to the reshape? Judging by issue #19, someone else ran into the same kind of cat.
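To make the shapes in that line concrete, here is a small sketch that mirrors it with NumPy instead of PyTorch (np.transpose plays the role of permute, np.concatenate the role of torch.cat). Random arrays stand in for the ClassHead outputs; the helper names, the anchor count of 2, and the three feature-map sizes are assumptions for illustration only. It shows that after permute(0,2,3,1).reshape(batch, -1, 2) every input to the cat is a rank-3 (batch, N, 2) tensor rather than a 4-D NCHW image map, which fits the suspicion that the reshape matters. Whether onnx2keras classifies these as OnnxTensor is not established by this sketch.

```python
import numpy as np

def class_head_output(batch, num_anchors, height, width, rng):
    """Hypothetical stand-in for self.ClassHead[i](feature).

    Returns an NCHW map with 2 scores per anchor; random values replace a real conv layer.
    """
    return rng.standard_normal((batch, num_anchors * 2, height, width))

def flatten_head(x, batch):
    # Mirrors .permute(0, 2, 3, 1).reshape(batch, -1, 2):
    # NCHW -> NHWC, then fold H, W and the anchors into one axis of score pairs.
    return np.transpose(x, (0, 2, 3, 1)).reshape(batch, -1, 2)

rng = np.random.default_rng(0)
batch, num_anchors = 1, 2
sizes = [(8, 8), (4, 4), (2, 2)]  # assumed feature-map sizes for three levels

heads = [class_head_output(batch, num_anchors, h, w, rng) for h, w in sizes]
flat = [flatten_head(x, batch) for x in heads]
classifications = np.concatenate(flat, axis=1)  # equivalent of torch.cat(..., dim=1)

print([x.shape for x in heads])  # [(1, 4, 8, 8), (1, 4, 4, 4), (1, 4, 2, 2)]
print([x.shape for x in flat])   # [(1, 128, 2), (1, 32, 2), (1, 8, 2)]
print(classifications.shape)     # (1, 168, 2)
```

A minimal PyTorch module built the same way (a few conv heads followed by permute, reshape and cat along dim=1) and exported with torch.onnx.export would be a natural candidate for the test model requested above.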

hakanardo commented 3 years ago

Superseded by PR #24.