Unity-Technologies / barracuda-release


ONNX - ArgMax/ArgMin support #111

Closed harmoniqpunk closed 3 years ago

harmoniqpunk commented 4 years ago

Hi,

I'm trying to import an ONNX model but I can not do this because Barracuda doesn't support ArgMax operator.

See this issue

Can you please add ArgMax operator or at least point me out a bit to learn how to add myself and contribute?

Thank you

AlexRibard commented 4 years ago

Hi @nauutilus, we do support it, just not when importing from ONNX. The way to do it is:

Tensor X = ...;
int[] argmax = X.ArgMax();

So for now you'll need to remove it from the model.
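For illustration, the "remove the op from the model and run ArgMax yourself" step can be sketched in NumPy (a stand-in for Barracuda's C# API, not actual Barracuda code; the tensor shapes here are hypothetical):

```python
import numpy as np

# Stand-in for the truncated model's output:
# (batch, joints, flattened_heatmap) -- shapes are hypothetical
heatmap = np.random.rand(1, 21, 64).astype(np.float32)

# Equivalent of Barracuda's X.ArgMax(): index of the peak per joint,
# computed outside the network after removing the ONNX ArgMax node
idx = np.argmax(heatmap, axis=2)
```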

AlexRibard commented 4 years ago

[image] Can you do this in C#? GatherElements is also unfortunately not supported.

harmoniqpunk commented 4 years ago

Hi Alex,

So you suggest I remove it completely from the model and return earlier, in order to perform the argmax in C#? Actually, that's a good idea, but I didn't know whether I can perform argmax and gather in C#. Do you know if there is a counterpart of the PyTorch gather function in the Barracuda inference engine? If so, can it be performed on GPU?

Thank you

AlexRibard commented 4 years ago

A few things:

Porting ArgMax to GPU is easy; however, there is a lot of complicated index manipulation at the end of the network, which will require some work to make it run fully. It looks to me like you are trying to gather the argmax values along a few axes and concatenate them all into one tensor. This looks awfully like TopK (https://github.com/onnx/onnx/blob/master/docs/Operators.md#Topk), which we support :) (only on CPU, and it might not work with 5D inputs, but I have some ideas on how to make it work). Do you think this would work?
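The TopK-for-ArgMax substitution works because TopK with k=1 returns exactly the maximum value and its index. A small NumPy sketch of that equivalence (illustrative only; the real op would run inside the ONNX graph):

```python
import numpy as np

x = np.array([[0.1, 0.9, 0.3],
              [0.7, 0.2, 0.5]], dtype=np.float32)

# ONNX-style TopK with k=1 along the last axis: top value and its index
k = 1
topk_idx = np.argsort(-x, axis=-1)[..., :k]            # indices of the top-1 element
topk_val = np.take_along_axis(x, topk_idx, axis=-1)    # the corresponding values
```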

harmoniqpunk commented 4 years ago

At this point I was ready to manipulate the arrays directly, even if it is not elegant.

I managed to get rid of that rank-5 tensor, and now I'm dealing only with a rank-3 tensor, so from that perspective I'm fine.

But I need the model on GPU because I need low latency. Is there any list of supported operators on GPU like this one?

AlexRibard commented 4 years ago

We currently do not have such a list. But it's a good idea, we should definitely add that :)

harmoniqpunk commented 4 years ago

Ok. Until you publish that list, can you help me by telling me whether my model has any other operators besides ArgMax and GatherElements that have no GPU support? Also, can direct manipulation of arrays be done on GPU?

https://github.com/nauutilus/InterHand2.6M/releases/download/0.0.1/interhand.onnx

Thank you

AlexRibard commented 4 years ago

Yes, I'll do that. GatherElements is not supported, be it on CPU or GPU, at the moment. Tensor manipulation will require compute shaders. In the current state of things, they are still not fully exposed to users, but I can make you a small demo project if you wish.

harmoniqpunk commented 4 years ago

That would help me, Alex: an example of how I can use compute shaders to perform operations on the GPU. I have never done this. Does this use Vulkan or OpenCL?

If it helps you for this example, I have truncated the model so the output is a rank-3 tensor (batch, joints, 3d_coordinates).

Here is the truncated ONNX model: https://github.com/nauutilus/InterHand2.6M/releases/download/0.0.3/interhand.onnx

In PyTorch I would take this rank 3 tensor and apply these operations:

joint_heatmap_out is the rank-3 tensor (batch, joints, 3d_coordinates), the current output of the truncated model:

    out = {}

    idx = torch.argmax(joint_heatmap_out, dim=2, keepdim=True)
    idx_z = idx // (cfg.output_hm_shape[1]*cfg.output_hm_shape[2])
    idx_y = idx % (cfg.output_hm_shape[1]*cfg.output_hm_shape[2]) // cfg.output_hm_shape[2]
    idx_x = idx % (cfg.output_hm_shape[1]*cfg.output_hm_shape[2]) % cfg.output_hm_shape[2]

    joint_z = torch.gather(joint_heatmap_out, dim=2, index=idx_z)
    joint_y = torch.gather(joint_heatmap_out, dim=2, index=idx_y)
    joint_x = torch.gather(joint_heatmap_out, dim=2, index=idx_x)

    joint_coord_out = torch.cat((joint_x, joint_y, joint_z), 2).float()
    out['joint_coord'] = joint_coord_out
    out['rel_root_depth'] = rel_root_depth_out
    out['hand_type'] = hand_type
    out['inv_trans'] = inv_trans
    out['target_joint'] = target_joint_coord
    out['joint_valid'] = joint_valid
    out['hand_type_valid'] = hand_type_valid
    return out
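The index arithmetic in that snippet just unflattens a linear argmax index into (z, y, x) heatmap coordinates. A NumPy check of the decomposition, with hypothetical values standing in for cfg.output_hm_shape:

```python
import numpy as np

# Hypothetical heatmap dimensions standing in for cfg.output_hm_shape = (D, H, W)
D, H, W = 4, 8, 8
heatmap = np.random.rand(1, 21, D * H * W).astype(np.float32)

# Flat argmax index, then the same unflattening as the PyTorch code
idx = np.argmax(heatmap, axis=2)
idx_z = idx // (H * W)
idx_y = idx % (H * W) // W
idx_x = idx % (H * W) % W

# Recombining recovers the flat index, so the split is lossless
recombined = idx_z * H * W + idx_y * W + idx_x
```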

Just a brief overview, to understand what I want to achieve. This model is meant to control the armature of a humanoid robotic hand from a single RGB image. The input is an image, and the output of the model should be the 3D coordinates of each joint. I would take the input from a 16 FPS webcam stream and feed it into the Barracuda worker input. I take the output from the model as a rank-3 tensor, perform the indexing in a compute shader as you suggest, and then control the armature joints. So I need to do inference on a MacBook GPU at 16 FPS, at the lowest possible latency.

It would be great if you could show me an example of how I can use compute shaders. I'm not an experienced Unity dev, but I have experience in low-level graphics computing like Vulkan and OpenCL.

Thank you

AlexRibard commented 4 years ago

Ok! thanks for explaining that to me :)

Compute shaders are a simpler version of OpenCL kernels, so you would probably have no issues with them. But I'll write you some code to dispatch one and manipulate the tensor.

AlexRibard commented 4 years ago

@nauutilus here is the promised code: TensorManipulationInComputeShader.zip. On a side note, we are adding ArgMax support. If the compute shader doesn't solve your problem, maybe you can try replacing your logic using Max -> Where?

Let me know if it helps.
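A minimal NumPy sketch of the suggested Max -> Where trick (illustrative only, not Barracuda code; it assumes the maximum is unique, otherwise several indices come back):

```python
import numpy as np

x = np.array([0.2, 0.8, 0.5], dtype=np.float32)

# Max -> Where: reduce to the maximum, then find where the input equals it
m = x.max()
idx = int(np.where(x == m)[0][0])  # first matching position == argmax
```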

harmoniqpunk commented 4 years ago

Thank you @AlexRibard. I just looked at your sample code and I got it.

If the compute shader doesn't solve your problem, maybe you can try replacing your logic using Max -> Where?

This is the route I was ready to take, but I like the shader solution.

mantasp commented 3 years ago

ArgMax support landed in Barracuda 1.2.1