Following the same approach as with batched sparse matrix multiplication, a sparse block diagonal matrix is used to perform the calculation over the batch elements.
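As an illustration of the block diagonal technique, here is a minimal sketch applied to a batched sparse-dense matrix product. The change itself may target a different operation. The sketch assumes a batched sparse COO input of shape (B, M, K) with all three dimensions sparse. The helper names `block_diag_sparse` and `batched_sparse_dense_mm` are hypothetical and are not part of this change or of the PyTorch API. Each batch element's entries are shifted onto their own diagonal block, so a single 2-D sparse matmul computes every `a[i] @ x[i]`.

```python
import torch


def block_diag_sparse(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn a batched sparse COO tensor of shape
    (B, M, K) into a 2-D sparse block diagonal matrix of shape (B*M, B*K)."""
    a = a.coalesce()
    b, m, k = a.shape
    batch, rows, cols = a.indices()  # indices() has shape (3, nnz)
    # Offset rows and columns so batch element i lands in diagonal block i.
    new_rows = batch * m + rows
    new_cols = batch * k + cols
    return torch.sparse_coo_tensor(
        torch.stack([new_rows, new_cols]), a.values(), (b * m, b * k)
    )


def batched_sparse_dense_mm(a: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute a[i] @ x[i] for every batch element i
    using one sparse matmul against the block diagonal matrix."""
    b, m, k = a.shape
    n = x.shape[-1]
    blocks = block_diag_sparse(a)
    # Stack the dense operands vertically so row block i multiplies x[i].
    out = torch.sparse.mm(blocks, x.reshape(b * k, n))
    return out.reshape(b, m, n)
```

For example, with `a = dense.to_sparse()` for a dense tensor `dense` of shape (3, 4, 5) and `x` of shape (3, 5, 2), `batched_sparse_dense_mm(a, x)` should match `torch.bmm(dense, x)` up to floating-point tolerance. The design trades B small products for one larger sparse product whose off-diagonal blocks are empty and therefore cost no work.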
Unit tests have been updated to use the pytest framework.
Issue GH-88890, which previously required a workaround, has been resolved in PyTorch 2.0, and the code has been updated to reflect this.