Closed: mayjs closed this issue 11 months ago
WebGPU (and hence wgpu) imposes limits on the number of workgroups (see here). The error occurs when invoking dispatch with workgroup counts that exceed these limits.
Many ops in wonnx determine the workgroup count dynamically so as to stay within these limits. I am not sure whether you are doing this in your op implementation, but I would suggest looking at the code of other ops in compiler.rs.
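For context, the per-dimension limit can be checked before dispatching. Below is a minimal sketch (dispatch_is_valid is a hypothetical helper, not part of wonnx), assuming WebGPU's default maxComputeWorkgroupsPerDimension of 65535, which is the limit shown in the error in this issue:

```rust
// Hypothetical pre-dispatch check (not part of wonnx): WebGPU's default
// maxComputeWorkgroupsPerDimension is 65535, and the limit applies to
// each of the x, y and z workgroup counts separately.
const MAX_WORKGROUPS_PER_DIM: u32 = 65_535;

fn dispatch_is_valid(x: u32, y: u32, z: u32) -> bool {
    x <= MAX_WORKGROUPS_PER_DIM && y <= MAX_WORKGROUPS_PER_DIM && z <= MAX_WORKGROUPS_PER_DIM
}
```

With the numbers from this issue, dispatch_is_valid(96_800, 1, 1) is false, while spreading the same total over two dimensions, e.g. (48_400, 2, 1), passes.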
I see, thanks for the link; I missed that part in the compiler source code. But this issue is not caused by my current implementation: it's the already existing Concat implementation.
I think it could be possible to distribute the workgroups differently by changing @workgroup_size(256, 1, 1) to @workgroup_size(16, 16, 1) and calculating the workgroup counts accordingly (the limits apply per dimension, so bringing each dimension down to a lower value should help here). I'm going to give that a go and try to refactor the Concat implementation.
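The idea above could be sketched like this (a hypothetical helper, not wonnx code): with @workgroup_size(16, 16, 1) there are still 16 × 16 = 256 threads per workgroup, and the group count is spread over the x and y dispatch dimensions, roughly as a square, so neither dimension comes near the 65535 limit:

```rust
// Hypothetical sketch (not wonnx code): spread the workgroup count over
// two dispatch dimensions so that neither exceeds WebGPU's default
// per-dimension limit of 65535. Assumes 256 threads per workgroup.
fn ceil_div(a: u64, b: u64) -> u64 {
    (a + b - 1) / b
}

fn workgroup_counts_2d(n_elements: u64) -> (u32, u32) {
    // One thread per element, 16 * 16 = 256 threads per workgroup.
    let groups = ceil_div(n_elements, 256).max(1);
    // Spread the groups over x and y, roughly as a square.
    let x = (groups as f64).sqrt().ceil() as u64;
    let y = ceil_div(groups, x);
    (x as u32, y as u32)
}
```

For the tensor sizes in this issue this yields counts in the low hundreds per dimension, far below the limit; the shader then has to reconstruct a flat element index from both dispatch dimensions (and bounds-check it, since x * y groups can slightly overshoot the exact count).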
:+1:
Most ops call workgroup_size to determine the appropriate workgroup size dynamically. It appears the implementation of Concat (actually one of the first ops to be implemented in wonnx) does not do that, so overflowing the limits is expected. Happy to review a PR that fixes this (it shouldn't be that difficult!).
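One possible spill-over strategy for such a helper looks like the following sketch (hypothetical, in the spirit of but not identical to wonnx's actual workgroup_size): fill the x dimension up to the per-dimension limit, then spill the remaining groups into y, and finally z:

```rust
// Hypothetical spill-over sketch (not wonnx's actual implementation):
// fill x up to the per-dimension limit, then spill into y, then z.
const LIMIT: u64 = 65_535;

fn ceil_div(a: u64, b: u64) -> u64 {
    (a + b - 1) / b
}

fn workgroup_counts(total_threads: u64, threads_per_group: u64) -> (u64, u64, u64) {
    let groups = ceil_div(total_threads, threads_per_group).max(1);
    let x = groups.min(LIMIT);
    let rest = ceil_div(groups, x);
    let y = rest.min(LIMIT);
    let z = ceil_div(rest, y);
    (x, y, z)
}
```

For this issue's 24,780,800 output elements at 256 threads per group this gives (65535, 2, 1). The shader must then rebuild a flat index from the 3D workgroup id and bounds-check it, since x * y * z can overshoot the exact group count.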
Sure, I'll take a look at other operations and try to port the logic to Concat :+1:
I also got my network to run using the approach I described above; I just need to actually implement a proper shader for ConvTranspose now.
I think I'm also going to add a unit test for Concat, as it seems like there is no test for that right now.
Describe the bug
I get the error ComputeLimitExceeded("X threads", 96800, 65535) when loading a network that concatenates two tensors. The inputs are two tensors with shape [1, 64, 440, 440].

To Reproduce
Load a network that has a similar input size. My ONNX sample input wouldn't be helpful here, because it uses a ConvTranspose operation which I am working on adding to wonnx in my fork.
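For what it's worth, the reported count of 96800 is consistent with one thread per output element at 256 threads per workgroup (an assumption about the Concat shader, but it matches the numbers exactly):

```rust
// Assumption (consistent with the error values): the Concat shader
// launches one thread per output element with @workgroup_size(256, 1, 1).
// Concatenating two [1, 64, 440, 440] tensors gives
// 2 * 1 * 64 * 440 * 440 = 24,780,800 output elements.
fn concat_workgroup_count() -> u64 {
    let output_elements: u64 = 2 * 1 * 64 * 440 * 440;
    (output_elements + 255) / 256 // ceil division by 256 threads per group
}
```

This evaluates to 96,800 workgroups on the x dimension, exactly the first value in ComputeLimitExceeded and well above the 65,535 limit.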
Expected behavior
The concatenation should work.
Screenshots
I added some output to print the input and output shapes in the compiler:
Is this a general limitation and an indicator that the network might be too complex for wonnx, or would it make sense to consider bumping the workgroup size, or maybe making it dynamic depending on the input size?