webonnx / wonnx

A WebGPU-accelerated ONNX inference run-time written 100% in Rust, ready for native and the web

ComputeLimitExceeded("X threads", 96800, 65535) #178

Closed: mayjs closed this issue 11 months ago

mayjs commented 11 months ago

Describe the bug: I get the error ComputeLimitExceeded("X threads", 96800, 65535) when loading a network that concatenates two tensors. The inputs are two tensors with shape [1, 64, 440, 440].

To reproduce: Load a network with a similar input size. My ONNX sample model wouldn't be helpful here, because it uses a ConvTranspose operation, which I am working on adding to wonnx in my fork.

Expected behavior: The concatenation should work.

Screenshots: I added some debug output to print the input and output shapes in the compiler:

[2023-07-23T19:54:42Z DEBUG wonnx::compiler] [
        Shape {
            dims: [
                1,
                64,
                440,
                440,
            ],
            data_type: F32,
        },
        Shape {
            dims: [
                1,
                64,
                440,
                440,
            ],
            data_type: F32,
        },
    ]
[2023-07-23T19:54:42Z DEBUG wonnx::compiler] [
        Shape {
            dims: [
                1,
                128,
                440,
                440,
            ],
            data_type: F32,
        },
    ]

Is this a general limitation, indicating that the network might be too complex for wonnx, or would it make sense to consider bumping the workgroup size, or maybe making it dynamic depending on the input size?

pixelspark commented 11 months ago

WebGPU (and hence wgpu) imposes limits on the number of workgroups per dispatch dimension (see here). The error occurs when dispatch is invoked with workgroup counts that exceed these limits.
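
The numbers in the error line up with the Concat output shape in the log above if one assumes one shader invocation per output element and 256 invocations per workgroup (the @workgroup_size(256,1,1) mentioned later in the thread): 1 × 128 × 440 × 440 = 24,780,800 elements, and 24,780,800 / 256 = 96,800 workgroups along X, above the default maxComputeWorkgroupsPerDimension of 65,535. A minimal sketch of that arithmetic; the helper names and the error message are hypothetical and are not wonnx's API:

```rust
/// Default WebGPU limit `maxComputeWorkgroupsPerDimension`.
const MAX_WORKGROUPS_PER_DIMENSION: u64 = 65_535;

/// Hypothetical helper: workgroups needed along X when every element gets
/// one invocation and each workgroup runs `threads_per_workgroup` of them.
fn workgroups_needed(dims: &[u64], threads_per_workgroup: u64) -> u64 {
    let elements: u64 = dims.iter().product();
    // Ceiling division: a partially filled last workgroup still counts.
    (elements + threads_per_workgroup - 1) / threads_per_workgroup
}

/// Hypothetical check mirroring the per-dimension limit (message is illustrative).
fn check_dispatch_x(count: u64) -> Result<u64, String> {
    if count > MAX_WORKGROUPS_PER_DIMENSION {
        Err(format!(
            "X workgroups: {} exceeds limit {}",
            count, MAX_WORKGROUPS_PER_DIMENSION
        ))
    } else {
        Ok(count)
    }
}

fn main() {
    // Concat output shape from the debug log: [1, 128, 440, 440].
    let output = [1u64, 128, 440, 440];
    let count = workgroups_needed(&output, 256);
    println!("{}", count); // 96800
    println!("{:?}", check_dispatch_x(count)); // Err("X workgroups: 96800 exceeds limit 65535")
}
```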

Many ops in wonnx determine the workgroup count dynamically so as to stay within these limits. I am not sure whether you are doing this in your op implementation, but I would suggest looking at the code of other ops in compiler.rs.

mayjs commented 11 months ago

I see, thanks for the link, I missed that part in the compiler source code.

But this issue is not caused by my current implementation; it comes from the existing Concat implementation. I think it should be possible to distribute the workgroups differently by changing @workgroup_size(256,1,1) to @workgroup_size(16,16,1) and calculating the workgroup counts accordingly (since the limits are applied per dimension, getting each dimension down to a lower value should help here).
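
The key point in this proposal is the per-dimension limit: a 16×16 workgroup still has 256 invocations, so the total workgroup count stays the same, and what helps is spreading that count over the X and Y dispatch dimensions. A minimal sketch of one way to do that, assuming the shader keeps flat element indexing and bounds-checks its index; split_dispatch_2d is a hypothetical helper, not necessarily what the eventual fix does:

```rust
/// Default WebGPU `maxComputeWorkgroupsPerDimension`.
const MAX_PER_DIM: u32 = 65_535;

/// Hypothetical helper: spread `total` workgroups over X and Y so neither
/// exceeds `max_per_dim`. Returns None if even a 2D dispatch cannot fit.
/// x * y may slightly exceed `total`, so the shader must bounds-check, e.g.
/// in WGSL (sketch):
///   let wg = wid.y * num_wg.x + wid.x;          // workgroup_id, num_workgroups
///   let idx = wg * 256u + local_invocation_index;
///   if (idx >= element_count) { return; }
fn split_dispatch_2d(total: u64, max_per_dim: u32) -> Option<(u32, u32)> {
    if total == 0 {
        return Some((0, 0)); // nothing to dispatch
    }
    let max = max_per_dim as u64;
    // Fewest rows that keep each row within the limit...
    let y = (total + max - 1) / max;
    // ...then balance the rows so x * y stays close to `total`.
    let x = (total + y - 1) / y;
    if y > max {
        None
    } else {
        Some((x as u32, y as u32))
    }
}

fn main() {
    let total = 96_800; // workgroup count from the Concat example
    let (x, y) = split_dispatch_2d(total, MAX_PER_DIM).expect("too many workgroups even in 2D");
    println!("dispatch_workgroups({}, {}, 1)", x, y); // dispatch_workgroups(48400, 2, 1)
}
```

Balancing x against y (instead of filling X up to 65,535 first) keeps the number of idle workgroups small; for 96,800 workgroups the split is exact.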

I'm going to give that a go and try to refactor the Concat implementation.

pixelspark commented 11 months ago

:+1:

Most ops call workgroup_size to determine the appropriate workgroup size dynamically. It appears the implementation of Concat (actually one of the first ops to be implemented in wonnx) does not do that, so exceeding the limits is expected. Happy to review a PR that fixes this (it shouldn't be that difficult!).

mayjs commented 11 months ago

Sure, I'll take a look at other operations and try to port the logic to Concat :+1: I also got my network to run using the approach I described above; now I just need to implement a proper shader for ConvTranspose.

I think I'm also going to add a unit test for Concat, as there seems to be no test for it right now.
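
A test like that needs expected values; one option is to compare the GPU result against a plain CPU implementation of Concat. A minimal sketch, assuming row-major tensors whose shapes agree on every dimension except axis; the Tensor struct and concat_reference function are hypothetical and not part of wonnx:

```rust
/// Minimal row-major tensor used only for this sketch (hypothetical).
#[derive(Debug, Clone, PartialEq)]
struct Tensor {
    data: Vec<f32>,
    shape: Vec<usize>,
}

/// Hypothetical CPU reference for Concat along `axis`.
/// Assumes all inputs agree on every dimension except `axis` (not validated).
fn concat_reference(inputs: &[Tensor], axis: usize) -> Tensor {
    let first = &inputs[0].shape;
    // Product of the dims before and after the concat axis.
    let outer: usize = first[..axis].iter().product();
    let inner: usize = first[axis + 1..].iter().product();
    let mut shape = first.clone();
    shape[axis] = inputs.iter().map(|t| t.shape[axis]).sum();
    let mut data = Vec::with_capacity(shape.iter().product());
    // For each outer index, each input contributes one contiguous block.
    for o in 0..outer {
        for t in inputs {
            let block = t.shape[axis] * inner;
            data.extend_from_slice(&t.data[o * block..(o + 1) * block]);
        }
    }
    Tensor { data, shape }
}

fn main() {
    let a = Tensor { data: vec![1.0, 2.0, 3.0, 4.0], shape: vec![2, 2] };
    let b = Tensor { data: vec![5.0, 6.0], shape: vec![2, 1] };
    let c = concat_reference(&[a, b], 1);
    println!("{:?} {:?}", c.shape, c.data); // [2, 3] [1.0, 2.0, 5.0, 3.0, 4.0, 6.0]
}
```

Applied to the two [1, 64, 440, 440] inputs from the log with axis 1, the reference produces the [1, 128, 440, 440] output shape shown there.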