sonos / tract

Tiny, no-nonsense, self-contained, Tensorflow and ONNX inference

Gather analysis seems to delete singleton dimensions now #1193

Closed alexander-camuto closed 1 year ago

alexander-camuto commented 1 year ago

Apologies for another Gather issue, but it appears that #1191, though it fixes #1190 and #1187, now deletes singleton dimensions during analysis in other edge cases of models that our users have just sent over.

As an example, consider the following trace generated by running tract network.onnx dump --io-long:


┏ 0 Source input
┃   * output fact #0: batch_size,4,F32 >2/0 MODEL INPUT #0 
┣┻ 2 Gather Gather_0
┃   * input fact  #0: 0/0> batch_size,4,F32
┃     input fact  #1: 1/0> 6,I64 3, 2, 3, 2, 0, 2
┃   * output fact #0: batch_size,6,F32 >3/0  onnx::Cast_8
┣ 3 Cast Cast_1
┃   * input fact  #0: 2/0> batch_size,6,F32
┃   * output fact #0: batch_size,6,F64 >5/1  onnx::Less_9
┣┻ 5 Less Less_2
┃   * input fact  #0: 4/0> 1,6,F64 0.800000011920929, 4.8500001430511475, 1.699999988079071, 4.950000047683716, 6.049999952316284, 5.049999952316284
┃     input fact  #1: 3/0> batch_size,6,F64
┃   * output fact #0: batch_size,6,Bool >6/0  onnx::Cast_10
┣ 6 Cast Cast_3
┃   * input fact  #0: 5/0> batch_size,6,Bool
┃   * output fact #0: batch_size,6,I8 >8/0 >13/0  onnx::MatMul_11
┣┓  
┃┣┻ 8 EinSum MatMul_4
┃┃   * input fact  #0: 6/0> batch_size,6,I8
┃┃     input fact  #1: 7/0> 6,7,I8 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1...
┃┃   * output fact #0: batch_size,7,I8 >9/0 >14/0  onnx::Cast_13
┃┣┓ 
┃┃┣ 9 Cast Cast_6
┃┃┃   * input fact  #0: 8/0> batch_size,7,I8
┃┃┃   * output fact #0: batch_size,7,TDim >11/0  onnx::Equal_16
┃┃┣┻ 11 Equals Equal_7
┃┃┃   * input fact  #0: 9/0> batch_size,7,TDim
┃┃┃     input fact  #1: 10/0> 1,7,TDim 0, 1, 2, 3, 4, 4, 3
┃┃┃   * output fact #0: batch_size,7,Bool >15/0  onnx::And_17
┗━┓  
┃┃┣┻ 13 EinSum MatMul_8
┃┃┃   * input fact  #0: 6/0> batch_size,6,I8
┃┃┃     input fact  #1: 12/0> 6,7,I8 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1...
┃┃┃   * output fact #0: batch_size,7,I8 >14/1  onnx::Equal_19
┗┓┃ 
┃┣┻ 14 Equals Equal_9
┃┃   * input fact  #0: 8/0> batch_size,7,I8
┃┃     input fact  #1: 13/0> batch_size,7,I8
┃┃   * output fact #0: batch_size,7,Bool >15/1  onnx::And_20
┣┻ 15 And And_10
┃   * input fact  #0: 11/0> batch_size,7,Bool
┃     input fact  #1: 14/0> batch_size,7,Bool
┃   * output fact #0: batch_size,7,Bool >16/0  onnx::Cast_21
┣ 16 Cast Cast_11
┃   * input fact  #0: 15/0> batch_size,7,Bool
┃   * output fact #0: batch_size,7,I32 >17/0  onnx::ArgMax_22
┣ 17 Reduce<ArgMax(false)> ArgMax_12
┃   * input fact  #0: 16/0> batch_size,7,I32
┃   * output fact #0: batch_size,1,I64 >18/0  
┣ 18 RmAxis ArgMax_12-dispose-dims-1
┃   * input fact  #0: 17/0> batch_size,1,I64
┃   * output fact #0: batch_size,I64 >20/1  onnx::Gather_23
┣┻ 20 Gather Gather_13
┃   * input fact  #0: 19/0> 7,1,3,F64 38, 0, 0, 0, 36, 0, 0, 2, 0, 0, 0, 1...
┃     input fact  #1: 18/0> batch_size,I64
┃   * output fact #0: batch_size,1,3,F64 >22/0  onnx::Gather_24
┣┻ 22 Gather Gather_15
┃   * input fact  #0: 20/0> batch_size,1,3,F64
┃     input fact  #1: 21/0> ,I64 0
┃   * output fact #0: batch_size,3,F64 >23/0  onnx::ArgMax_31
┣ 23 Reduce<ArgMax(false)> ArgMax_21
┃   * input fact  #0: 22/0> batch_size,3,F64
┃   * output fact #0: batch_size,1,I64 >24/0  
┣ 24 RmAxis ArgMax_21-dispose-dims-1
┃   * input fact  #0: 23/0> batch_size,1,I64
┃   * output fact #0: batch_size,I64 >26/1  onnx::Gather_32
┣┻ 26 Gather Gather_22
    * input fact  #0: 25/0> 3,I64 0, 1, 2
      input fact  #1: 24/0> batch_size,I64
    * output fact #0: batch_size,I64  MODEL OUTPUT #0 output

If we focus on node 22:

┣┻ 22 Gather Gather_15
┃   * input fact  #0: 20/0> batch_size,1,3,F64
┃     input fact  #1: 21/0> ,I64 0
┃   * output fact #0: batch_size,3,F64 >23/0  onnx::ArgMax_31

We find that the gather op, which gathers along axis 1 of an input of shape batch_size,1,3, generates an output of shape batch_size,3 -- which is not the behaviour we expected.
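To make the shapes concrete, the same arithmetic can be mimicked with numpy's take, which follows the same shape rule as ONNX Gather (just a sketch, with batch_size stood in by 2; this is not the tract code path):

>>> import numpy as np
>>> x = np.zeros((2, 1, 3))                   # stands in for batch_size,1,3
>>> np.take(x, np.array(0), axis=1).shape     # rank-0 index: gathered axis removed
(2, 3)
>>> np.take(x, np.array([0]), axis=1).shape   # rank-1 index of length 1: axis kept
(2, 1, 3)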

I've attached one of the edge case models for which this happens:

network.onnx.zip

kali commented 1 year ago

Mmm... Is it really? I'm obviously struggling a bit to make sense of the ONNX spec.

https://github.com/onnx/onnx/blob/main/docs/Operators.md#gather

So in this case, the input is of rank r=3 and the indices are of rank q=0, so the output should be of rank q + r - 1 = 2, right?
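Put differently, the spec's output shape is data.shape[:axis] + indices.shape + data.shape[axis+1:], so a rank-0 indices tensor removes the gathered axis. A quick sketch of that rule (a hypothetical helper, not tract code):

>>> def gather_shape(data_shape, indices_shape, axis):
...     return data_shape[:axis] + indices_shape + data_shape[axis + 1:]
...
>>> gather_shape((1, 1, 3), (), axis=1)       # q=0: output rank is q + r - 1 = 2
(1, 3)
>>> gather_shape((1, 1, 3), (1,), axis=1)     # q=1: rank 3 is preserved
(1, 1, 3)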

alexander-camuto commented 1 year ago

I see what you mean.

I've seen rank used pretty loosely, and passing singleton (non-matrix) indices like this is not even defined behaviour in torch or tf:


>>> import torch
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[1, 1],
        [4, 3]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])).shape
torch.Size([2, 2])
>>> torch.gather(t, 1, torch.tensor([0])).shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Index tensor must have the same number of dimensions as input tensor

I actually think your interpretation here is correct, as the subsequent shapes/axes implied by the ONNX graph don't make sense without the deleted axis.

This is the node in question:

[screenshot: the Gather node]

Which is then sliced along axis 1:

[screenshot: the Slice along axis 1]

This wouldn't really make sense for an input of size [1,1,3], but it does for an input of size [1,3]. The latter happens if we interpret the singleton index as being of rank 0 -- and delete the gathered axis accordingly.
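A quick way to check this with numpy (a sketch only: the exact downstream slice isn't reproduced here, so [:, 1:] is just an illustrative stand-in for a slice along axis 1):

>>> import numpy as np
>>> x = np.arange(3.0).reshape(1, 1, 3)
>>> np.take(x, np.array(0), axis=1)[:, 1:].shape    # axis deleted: the slice sees size 3
(1, 2)
>>> np.take(x, np.array([0]), axis=1)[:, 1:].shape  # axis kept: the slice hits the size-1 axis
(1, 0, 3)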

Thank you for clarifying!

kali commented 1 year ago

Yeah, there is a lot of abuse in the way trivial dimensions get discarded or added all over the place. I think the current implementation matches the spec.

But if we have an important source of such abuse somewhere in the ecosystem (like torch+torch-to-nnef), we may need to go against the spec and implement workarounds (if possible at all). Do you know what generated these models that look invalid?

alexander-camuto commented 1 year ago

Yeah, the sk2torch Python package applied to sklearn decision trees, so it may be particular to that package.

kali commented 1 year ago

I'm strangely not surprised; sklearn decision trees are a repeat offender in axis abuse. They have pre-everything-is-a-tensor-era behaviours that are often pretty inconsistent and need custom code to figure out. Have you looked into the problem enough to suggest workaround strategies that could accommodate what they are generating without going off-spec?

alexander-camuto commented 1 year ago

I think I'll recommend the hummingbird-ml package for the conversion -- it seems to produce fewer (or no) weird axis ops.
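For the record, a minimal sketch of that route, assuming hummingbird-ml's documented convert entry point and that the returned container exposes the raw ONNX proto as .model (both worth double-checking against the hummingbird docs):

>>> import numpy as np, onnx
>>> from sklearn.tree import DecisionTreeClassifier
>>> from hummingbird.ml import convert
>>> X = np.random.rand(100, 4).astype(np.float32)
>>> y = np.random.randint(0, 3, size=100)
>>> clf = DecisionTreeClassifier(max_depth=4).fit(X, y)
>>> hb = convert(clf, "onnx", X)               # the onnx backend needs a test input for tracing
>>> onnx.save(hb.model, "network.onnx")        # serialize the ModelProto for tract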

Thank you for all your help.