dlsyscourse / hw3


Wrong strides in `ndarray.py`'s `__matmul__` #4

Closed: tiendung closed this issue 2 years ago.

tiendung commented 2 years ago

https://github.com/dlsyscourse/hw3/blob/2403e58f0994852745b9456dc8c451068a35c8a9/python/needle/backend_ndarray/ndarray.py#L494

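The strides tuple on this line should be `(a.shape[1] * tile, tile, a.shape[1], 1)`. Stepping one row inside a tile of the compact row-major matrix `a` advances by `a.shape[1]` elements, so that stride must be the column count of the matrix being tiled.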
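Below is a minimal NumPy sketch of the tiled view, assuming a compact row-major input. `tile` and `tiled_matmul` are hypothetical helpers written for illustration, not needle's code. NumPy's `as_strided` takes strides in bytes, while needle's strides count elements, so the helper multiplies the element strides by `itemsize`. With the corrected strides, each tile equals the matching slice of the original matrix, and the tiled product agrees with `x @ y` for non-square operands.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def tile(a: np.ndarray, t: int) -> np.ndarray:
    """View a row-major 2-D array as a (rows//t, cols//t, t, t) grid of tiles.

    Hypothetical helper: uses the element strides (cols * t, t, cols, 1),
    where cols is the column count of `a` itself.
    """
    a = np.ascontiguousarray(a)  # strides below assume compact row-major layout
    rows, cols = a.shape
    assert rows % t == 0 and cols % t == 0, "dimensions must be multiples of t"
    elem_strides = (cols * t, t, cols, 1)  # next tile-row, next tile-col, next row, next col
    return as_strided(
        a,
        shape=(rows // t, cols // t, t, t),
        strides=tuple(s * a.itemsize for s in elem_strides),  # NumPy wants bytes
        writeable=False,
    )


def tiled_matmul(x: np.ndarray, y: np.ndarray, t: int) -> np.ndarray:
    """Multiply x (m, n) by y (n, p) tile by tile. Hypothetical helper."""
    m, _ = x.shape
    _, p = y.shape
    xt = tile(x, t)  # (m/t, n/t, t, t)
    yt = tile(y, t)  # (n/t, p/t, t, t)
    # Output tile (i, j) is the sum over k of xt[i, k] @ yt[k, j].
    out_tiles = np.einsum("ikab,kjbc->ijac", xt, yt)
    # Untile: (m/t, p/t, t, t) -> (m/t, t, p/t, t) -> (m, p).
    return out_tiles.transpose(0, 2, 1, 3).reshape(m, p)


x = np.arange(24, dtype=np.int64).reshape(4, 6)
y = np.arange(48, dtype=np.int64).reshape(6, 8)
print(np.array_equal(tiled_matmul(x, y, 2), x @ y))
```

The example deliberately uses non-square operands (n = 6, p = 8). A row stride taken from the wrong matrix only goes unnoticed when the column counts happen to coincide, so square test matrices can hide this kind of bug.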