tensorflow / tfjs

A WebGL accelerated JavaScript library for training and deploying ML models.
https://js.tensorflow.org
Apache License 2.0
18.46k stars 1.92k forks

Dispatch size exceeds WebGPU limits in Y or Z dimension #8373

Open arcman7 opened 1 month ago

arcman7 commented 1 month ago

Please make sure that this is a bug. As per our GitHub Policy, we only address code/doc bugs, performance issues, feature requests and build/installation issues on GitHub.

System information

Describe the current behavior

When attempting to perform matrix multiplication using tf.matMul(r, s) where tensor r has shape [1493284, 3, 3] and tensor s has shape [1493284, 3, 3], an error is thrown: "Dispatch size exceeds WebGPU limits in Y or Z dimension."

The error occurs in the reshapeDispatch function of the WebGPU backend when it tries to handle a dispatch shape of [1, 1, 1493284].

Describe the expected behavior

The matrix multiplication should be performed successfully without throwing an error related to dispatch size limits.

Standalone code to reproduce the issue

const r = tf.zeros([1493284, 3, 3]);
const s = tf.zeros([1493284, 3, 3]);
const l = tf.matMul(r, s);

Other info / logs The error is triggered by this assertion in the WebGPU backend code:

tf.util.assert(
    dispatch[0] > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
        layout.y === undefined && layout.z === undefined,
    function () { return 'Dispatch size exceeds WebGPU limits in Y or Z dimension.'; });

This occurs in the reshapeDispatch function, which is called to handle the dispatch shape [1, 1, 1493284] generated by the matrix multiplication operation.

The full implementation of the reshapeDispatch function can be found at: https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-webgpu/dist/tf-backend-webgpu.js (around line 1252)
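For context, the reshaping that the assertion anticipates is only defined for a 1D layout: an oversized X workgroup count can be folded into the Y dimension. The helper below is a hypothetical sketch of that folding, not the actual tfjs implementation; it illustrates why a dispatch like [1, 1, 1493284], where the overflow sits in Z rather than X, cannot be handled this way.

```typescript
// Sketch (illustrative, not the tfjs source): fold an oversized 1D dispatch
// into Y so that no single dimension exceeds maxComputeWorkgroupsPerDimension
// (65535 by default in WebGPU).
function foldDispatch(
    total: number, maxPerDim: number): [number, number, number] {
  if (total <= maxPerDim) {
    return [total, 1, 1];
  }
  const y = Math.ceil(total / maxPerDim);
  if (y > maxPerDim) {
    throw new Error('Dispatch too large even for a 2D split.');
  }
  // The shader must then discard the (maxPerDim * y - total) excess
  // invocations via a bounds check on the flat global id.
  return [maxPerDim, y, 1];
}
```

With total = 1493284 and maxPerDim = 65535 this yields [65535, 23, 1], each dimension within the limit.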

arcman7 commented 1 month ago

Note: This same matrix multiplication operation works fine using the webgl backend.

shmishra99 commented 1 month ago

Hi @arcman7 ,

I've successfully replicated the issue you're experiencing and am seeing the same error.

[screenshot of the reproduced error]

I'm currently investigating further and will update you with my findings as soon as possible.

Thank you!

arcman7 commented 1 month ago

@shmishra99 I don't know if this is useful to you, but in my own WebGPU helper code I use something like this:

interface AdapterLimits {
  maxComputeInvocationsPerWorkgroup: number;
  maxComputeWorkgroupSizeX: number;
  maxComputeWorkgroupSizeY: number;
  maxComputeWorkgroupSizeZ: number;
  maxComputeWorkgroupsPerDimension: number;
}

interface WorkgroupConfig {
  workgroupSize: [number, number, number];
  dispatchSize: [number, number, number];
}

export function calculateWorkgroups(totalThreads: number, limits: AdapterLimits): WorkgroupConfig {
  const maxSize = Math.min(
    limits.maxComputeInvocationsPerWorkgroup,
    totalThreads
  );

  let x = Math.floor(Math.sqrt(maxSize));
  let y = Math.floor(maxSize / x);
  let z = 1;

  // Adjust x and y if they exceed their respective limits
  if (x > limits.maxComputeWorkgroupSizeX) {
    x = limits.maxComputeWorkgroupSizeX;
    y = Math.floor(maxSize / x);
  }

  if (y > limits.maxComputeWorkgroupSizeY) {
    y = limits.maxComputeWorkgroupSizeY;
    x = Math.floor(maxSize / y);
  }

  // If x or y still exceed their limits, use the z dimension
  if (x > limits.maxComputeWorkgroupSizeX || y > limits.maxComputeWorkgroupSizeY) {
    x = Math.min(x, limits.maxComputeWorkgroupSizeX);
    y = Math.min(y, limits.maxComputeWorkgroupSizeY);
    z = Math.floor(maxSize / (x * y));
    z = Math.min(z, limits.maxComputeWorkgroupSizeZ);
  }

  // Ensure the total doesn't exceed maxComputeInvocationsPerWorkgroup
  const total = x * y * z;
  if (total > limits.maxComputeInvocationsPerWorkgroup) {
    const scale = Math.cbrt(limits.maxComputeInvocationsPerWorkgroup / total);
    // Clamp to at least 1 so flooring a small scaled value never yields a
    // zero-sized dimension.
    x = Math.max(1, Math.floor(x * scale));
    y = Math.max(1, Math.floor(y * scale));
    z = Math.max(1, Math.floor(z * scale));
  }

  // Calculate the number of workgroups to dispatch. (If dispatchX itself
  // exceeds maxComputeWorkgroupsPerDimension, it must in turn be split
  // across the Y/Z dimensions.)
  const dispatchX = Math.ceil(totalThreads / (x * y * z));
  const dispatchY = 1;
  const dispatchZ = 1;

  return {
    workgroupSize: [x, y, z],
    dispatchSize: [dispatchX, dispatchY, dispatchZ]
  };
}

But then you still have to maintain properly calculating flat global ids wherever it applies.
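For illustration, the flat-global-id bookkeeping mentioned above could look like the following on the host side. The names here are illustrative, not tfjs or WebGPU APIs; in an actual shader the same arithmetic runs in WGSL using built-ins such as workgroup_id and local_invocation_id, followed by a bounds check against the element count.

```typescript
// Sketch: map a 3D (workgroup id, local id) pair back to one linear index,
// given the workgroupSize/dispatchSize pair produced by a helper like the
// calculateWorkgroups function above.
function flatGlobalId(
    workgroupId: [number, number, number],
    localId: [number, number, number],
    wgSize: [number, number, number],
    numWg: [number, number, number]): number {
  // Linearize the workgroup index across the dispatch grid...
  const wgIndex = workgroupId[0] +
      numWg[0] * (workgroupId[1] + numWg[1] * workgroupId[2]);
  // ...and the invocation index within the workgroup.
  const localIndex = localId[0] +
      wgSize[0] * (localId[1] + wgSize[1] * localId[2]);
  return wgIndex * (wgSize[0] * wgSize[1] * wgSize[2]) + localIndex;
}
```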

gaikwadrahul8 commented 1 month ago

Hi, @arcman7

I apologize for the delayed response. The maximum value of maxComputeWorkgroupsPerDimension is 65535 (please refer to the official WebGPU documentation), and that limit is what gets assigned to MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE. The assertion below requires that, when a dispatch dimension exceeds the limit, it is the X dimension and the layout is 1D (i.e. it has no Y or Z dimensions), since only that case can be reshaped. When the condition does not hold, as with the [1, 1, 1493284] dispatch here, tf.util.assert throws the error indicating that the dispatch size exceeds WebGPU limits.

tf.util.assert(
    dispatch[0] > MAX_COMPUTE_PER_DIMENSION_DISPATCH_SIZE &&
        layout.y === undefined && layout.z === undefined,
    function () { return 'Dispatch size exceeds WebGPU limits in Y or Z dimension.'; });

Case 01: When I tried a value less than or equal to 65535, it works as expected. I've added the output log below for your reference.

[screenshot of the output log for the working case]

Case 02: When I tried a value greater than 65535, it throws the same error message, Dispatch size exceeds WebGPU limits in Y or Z dimension. I've added the output log below for your reference.

[screenshot of the output log for the failing case]

It seems you'll have to use a value less than or equal to 65535 due to this maximum limit. If I have missed something here, please let me know.

Thank you for your cooperation and patience.

arcman7 commented 1 month ago

Hi @gaikwadrahul8, it is my understanding that maxComputeWorkgroupsPerDimension is the maximum for a single dimension, not for the whole dispatch. The code I posted above serves as an example of how to handle dispatching the much larger computations that WebGPU does allow. In the case where I have 1493284 rows of data to process, I would expect tfjs to internally batch the calls like so:

const N = 1493284;
dispatchWorkgroups(
  maxComputeWorkgroupsPerDimension, // x
  Math.ceil( N / maxComputeWorkgroupsPerDimension), // y
  1 // z
);

You could just as well first divide N by the workgroup size (256 is the default maximum number of invocations per workgroup) when N is greater than maxComputeWorkgroupsPerDimension.
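That two-step split can be sketched as follows. dispatchFor is a hypothetical helper, not a tfjs API; 256 and 65535 are the WebGPU default limits for maxComputeInvocationsPerWorkgroup and maxComputeWorkgroupsPerDimension, used here purely for illustration.

```typescript
// Sketch: first divide the N logical threads by the workgroup size, then
// split the resulting workgroup count across X and Y so that neither
// dimension exceeds maxComputeWorkgroupsPerDimension.
function dispatchFor(
    n: number, workgroupSize = 256,
    maxPerDim = 65535): [number, number, number] {
  const groups = Math.ceil(n / workgroupSize);
  const x = Math.min(groups, maxPerDim);
  const y = Math.ceil(groups / x);
  return [x, y, 1];
}
```

For N = 1493284 with a workgroup size of 256, this needs only 5834 workgroups, which fits in X alone; with a workgroup size of 1 it falls back to the [65535, 23, 1] split described above.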