dlsyscourse / hw3

ndarray.__matmul__: bugfix in tiling #8

Closed navalnica closed 11 months ago

navalnica commented 1 year ago

As mentioned in #4, the following line in ndarray.__matmul__.tile():

(a.shape[1] * tile, tile, self.shape[1], 1),

must be changed to:

(a.shape[1] * tile, tile, a.shape[1], 1),

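Otherwise, tile(other.compact(), t) returns a view with wrong strides whenever other.shape[1] differs from self.shape[1], because the row stride inside a tile must be the column count of the matrix being tiled, not of the left operand.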

Adding (8, 16, 8) to the matmul_dims params of the test_ndarray.test_matmul() test function demonstrates the issue: the test case fails when tile() uses the wrong stride calculation. After I fixed tile(), the newly added test passed.
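To make the failure concrete, here is a minimal pure-Python sketch of the stride arithmetic for the right-hand operand in the (8, 16, 8) case. It is not the repository's code: tile_offsets and expected_offsets are hypothetical helpers, the tile size of 4 is an assumption chosen for illustration, and strides are counted in elements for a compact row-major matrix.

```python
def tile_offsets(cols, tile, row_stride, i, j):
    """Flat element offsets read for tile (i, j) of a compact row-major
    matrix with `cols` columns, using the 4-D strides
    (cols * tile, tile, row_stride, 1)."""
    strides = (cols * tile, tile, row_stride, 1)
    base = i * strides[0] + j * strides[1]
    return [[base + r * strides[2] + c * strides[3] for c in range(tile)]
            for r in range(tile)]


def expected_offsets(cols, tile, i, j):
    """Offsets that tile (i, j) should cover: rows i*tile .. i*tile+tile-1,
    columns j*tile .. j*tile+tile-1."""
    return [[(i * tile + r) * cols + (j * tile + c) for c in range(tile)]
            for r in range(tile)]


# matmul dims (m, n, p) = (8, 16, 8): self is 8 x 16, other is 16 x 8.
m, n, p = 8, 16, 8
t = 4  # assumed tile size, for illustration only

# Right operand `other` has p columns; the buggy code used self.shape[1] == n.
fixed = tile_offsets(cols=p, tile=t, row_stride=p, i=0, j=0)
buggy = tile_offsets(cols=p, tile=t, row_stride=n, i=0, j=0)
want = expected_offsets(cols=p, tile=t, i=0, j=0)

fixed_ok = fixed == want
buggy_ok = buggy == want
buggy_first_column = [row[0] for row in buggy]
```

With row_stride set to n, each step down a tile row jumps 16 elements in a matrix whose rows are only 8 elements long, so the tile picks up every second row of other. When n equals p the two expressions coincide, which is why dims with n == p do not expose the bug.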

By the way, test_ndarray.test_matmul_tiled() cannot catch this error: it does not call tile() and instead creates its input matrices directly in the shapes required by nd.cpu().matmul_tiled().