thouis / numpy-trac-migration

numpy Trac to github issues migration

BLAS matrix product (dot) never used for ndim > 2 (tensordot does not use BLAS) (Trac #2163) #5955

Open numpy-gitbot opened 11 years ago

numpy-gitbot commented 11 years ago

Original ticket http://projects.scipy.org/numpy/ticket/2163 on 2012-06-15 by trac user thatistosay, assigned to unknown.

dotblas_matrixproduct() contains the comment "This function doesn't handle dimensions greater than 2" and calls PyArray_MatrixProduct2() for these cases. This means BLAS is never used for calls to dot() with arguments of ndim>2!! In more detail...

If I want to contract a pair of tensor indices for ndim=3 such that

A = np.random.rand(d,D,D); B = np.random.rand(d,D,D)
AB[s,i,t,j] == sum(A[s,i,:] * B[t,:,j])

then currently, although it can be done in a single line

res = np.dot(A,B)

it can often be done much faster with explicit (Python!) loops

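res = np.zeros((d,d,D,D))  # built with index order [s,t,i,j]
for s in range(d):
    for t in range(d):
        np.dot(A[s], B[t], out=res[s,t])  # one 2D dot() per (s,t) pair
res = np.rollaxis(res, 2, 1)  # reorder to [s,i,t,j], matching dot(A,B)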

...assuming dot() uses an optimized BLAS for ndim=2 and the dimensions are large enough that calling BLAS is worth it. In general, reproducing the behaviour of dot() for ndim>2 is just a matter of calling GEMM in loops as above and then calling rollaxis() once.
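The looping scheme above can be written as a small helper for the 3D case and checked against dot() directly. This is only a minimal sketch: `dot3d_via_2d` is a hypothetical name (not a NumPy function), and it assumes both inputs are 3D with A.shape[2] == B.shape[1].

```python
import numpy as np


def dot3d_via_2d(A, B):
    """Reproduce np.dot(A, B) for 3D inputs with one 2D dot() per (s, t) pair.

    Hypothetical helper for illustration only; not part of NumPy.
    """
    if A.ndim != 3 or B.ndim != 3:
        raise ValueError("both inputs must be 3D")
    if A.shape[2] != B.shape[1]:
        raise ValueError("A.shape[2] must equal B.shape[1]")

    # out= must have the dtype that dot() would return for these inputs.
    dtype = np.result_type(A, B)
    # The result is first built with index order [s, t, i, j].
    res = np.empty((A.shape[0], B.shape[0], A.shape[1], B.shape[2]), dtype=dtype)
    for s in range(A.shape[0]):
        for t in range(B.shape[0]):
            # res[s, t] is a C-contiguous 2D view, so it is a valid out= target.
            np.dot(A[s], B[t], out=res[s, t])
    # Swap the two middle axes to get dot()'s index order [s, i, t, j].
    return np.rollaxis(res, 2, 1)


# Small check against the single-call 3D dot().
A = np.random.rand(3, 4, 5)
B = np.random.rand(2, 5, 6)
assert np.allclose(dot3d_via_2d(A, B), np.dot(A, B))
```

The returned array is a view with swapped axes, so it is not C-contiguous; call np.ascontiguousarray() on it if a contiguous result is needed.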

I therefore propose doing this within _dotblas.c as far as possible (to eliminate the use of Python loops) so that dot() with ndim>2, and tensordot(), can benefit from BLAS.

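Some comparisons of the two methods above (attached script). Each percentage is the run time of the looped 2D dot() relative to a single 3D dot() call, so values below 100% mean the loop is faster: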

dtype=complex128
AB[s,i,t,j] = sum(A[s,i,:] * B[t,:,j])

A.shape = (16, 512, 512); B.shape = (16, 512, 512)
looping over 2D dot() vs. 3D dot(): 24% (about 4 times faster)

A.shape = (20, 64, 64); B.shape = (20, 64, 64)
looping over 2D dot() vs. 3D dot(): 35%

A.shape = (20, 48, 32); B.shape = (20, 32, 48)
looping over 2D dot() vs. 3D dot(): 45%

A.shape = (32, 32, 16); B.shape = (32, 16, 32)
looping over 2D dot() vs. 3D dot(): 82%

A.shape = (64, 10, 8); B.shape = (64, 8, 10)
looping over 2D dot() vs. 3D dot(): 158% (the Python loop overhead dominates)

(this was on a 4-core i7 system using ATLAS under heavy load)
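A comparison of this kind could be timed roughly as follows. This is a hedged sketch and not the attached benchmat.py: `looped_dot` and `relative_time` are hypothetical names, and the numbers it prints will depend on the NumPy version, the BLAS build, and the machine.

```python
import timeit

import numpy as np


def looped_dot(A, B):
    """3D dot() built from one 2D dot() per (s, t) pair (hypothetical helper)."""
    res = np.empty((A.shape[0], B.shape[0], A.shape[1], B.shape[2]),
                   dtype=np.result_type(A, B))
    for s in range(A.shape[0]):
        for t in range(B.shape[0]):
            np.dot(A[s], B[t], out=res[s, t])
    # Reorder [s, t, i, j] to [s, i, t, j], matching np.dot(A, B).
    return np.rollaxis(res, 2, 1)


def relative_time(shape_a, shape_b, repeats=3):
    """Run time of looped_dot as a percentage of a single 3D np.dot call."""
    rng = np.random.default_rng(0)
    # complex128 inputs, as in the figures quoted in the ticket.
    A = rng.random(shape_a) + 1j * rng.random(shape_a)
    B = rng.random(shape_b) + 1j * rng.random(shape_b)
    # Take the best of several runs to reduce noise from other processes.
    t_loop = min(timeit.repeat(lambda: looped_dot(A, B), number=1, repeat=repeats))
    t_dot = min(timeit.repeat(lambda: np.dot(A, B), number=1, repeat=repeats))
    return 100.0 * t_loop / t_dot


if __name__ == "__main__":
    # Two of the smaller shape pairs from the ticket, to keep the run short.
    for shape_a, shape_b in [((64, 10, 8), (64, 8, 10)),
                             ((32, 32, 16), (32, 16, 32))]:
        pct = relative_time(shape_a, shape_b)
        print(f"A.shape = {shape_a}; B.shape = {shape_b}: {pct:.0f}%")
```

Taking the minimum over repeats is a deliberate choice here, since the original measurements were made on a loaded system and single runs can be noisy.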

numpy-gitbot commented 11 years ago

Attachment added by trac user thatistosay on 2012-06-15: benchmat.py