Improve julia's expm performance

closed 11 years ago

ViralBShah commented 11 years ago

Improve expm performance by calling BLAS/LAPACK routines directly and using as few temporaries as possible.


ViralBShah commented 11 years ago

This myexpm implementation tries to reduce the number of temporaries through devectorization. It is about 30% faster.

function myexpm!{T}(A::StridedMatrix{T})
    m, n = size(A)
    if m != n error("myexpm!: Matrix A must be square") end
    if m < 2 return exp(A) end
    ilo, ihi, scale = LAPACK.gebal!('B', A)    # modifies A
    nA   = norm(A, 1)
    I    = convert(Array{T,2}, eye(n))
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = [17643225600.,8821612800.,2075673600.,302702400.,
                    30270240.,   2162160.,    110880.,     3960.,
                          90.,         1.]
        elseif nA > 0.25
            C = [17297280.,8648640.,1995840.,277200.,
                    25200.,   1512.,     56.,     1.]
        elseif nA > 0.015
            C = [30240.,15120.,3360.,
                   420.,   30.,   1.]
        else
            C = [120.,60.,12.,1.]
        end
        A2 = A * A
        P  = copy(I)
        U  = C[2] * P
        V  = C[1] * P
        for k in 1:(div(size(C, 1), 2) - 1)
            k2 = 2 * k
            P *= A2
            #U += C[k2 + 2] * P
            #V += C[k2 + 1] * P
            # Devectorized form of the two updates above: no temporaries
            Ck21 = C[k2 + 1]
            Ck22 = C[k2 + 2]
            for i=1:length(P)
                U[i] += Ck22 * P[i]
                V[i] += Ck21 * P[i]
            end
        end
        U  = A * U
        X  = (V - U)\(V + U)
    else
        s  = log2(nA/5.4)               # power of 2 later reversed by squaring
        if s > 0
            si = iceil(s)
            A /= 2^si
        end
        CC = [64764752532480000.,32382376266240000.,7771770303897600.,
               1187353796428800.,  129060195264000.,  10559470521600.,
                   670442572800.,      33522128640.,      1323241920.,
                       40840800.,           960960.,           16380.,
                            182.,                1.]
        A2 = A * A
        A4 = A2 * A2
        A6 = A2 * A4
#         U  = A * (A6 * (CC[14]*A6 + CC[12]*A4 + CC[10]*A2) +
#                   CC[8]*A6 + CC[6]*A4 + CC[4]*A2 + CC[2]*I)
#         V  = A6 * (CC[13]*A6 + CC[11]*A4 + CC[9]*A2) +
#                   CC[7]*A6 + CC[5]*A4 + CC[3]*A2 + CC[1]*I
        # Devectorized form of the commented-out expressions above:
        # build the four polynomial pieces in a single pass
        P1 = zeros(T, n, n)
        P2 = zeros(T, n, n)
        P3 = zeros(T, n, n)
        P4 = zeros(T, n, n)
        CC14 = CC[14]; CC12 = CC[12]; CC10 = CC[10]
        CC8 = CC[8];   CC6 = CC[6];   CC4 = CC[4];   CC2 = CC[2]
        CC13 = CC[13]; CC11 = CC[11]; CC9 = CC[9]
        CC7 = CC[7];   CC5 = CC[5];   CC3 = CC[3];   CC1 = CC[1]
        for i=1:length(I)
            P1[i] += CC14*A6[i] + CC12*A4[i] + CC10*A2[i]
            P2[i] += CC8*A6[i] + CC6*A4[i] + CC4*A2[i] + CC2*I[i]
            P3[i] += CC13*A6[i] + CC11*A4[i] + CC9*A2[i]
            P4[i] += CC7*A6[i] + CC5*A4[i] + CC3*A2[i] + CC1*I[i]
        end
        U = A * (A6*P1 + P2)
        V = A6*P3 + P4

        X  = (V-U)\(V+U)

        if s > 0            # squaring to reverse dividing by power of 2
            for t in 1:si X *= X end
        end
    end
    # Undo the balancing
    doscale = false                     # check if rescaling is needed
    for i = ilo:ihi
        if scale[i] != 1.
            doscale = true
        end
    end
    if doscale
        for j = ilo:ihi
            scj = scale[j]
            if scj != 1.                # is this overkill?
                for i = ilo:ihi
                    X[i,j] *= scale[i]/scj
                end
            else
                for i = ilo:ihi
                    X[i,j] *= scale[i]
                end
            end
        end
    end
    if ilo > 1       # apply lower permutations in reverse order
        for j in (ilo-1):-1:1 rcswap!(j, int(scale[j]), X) end
    end
    if ihi < n       # apply upper permutations in forward order
        for j in (ihi+1):n    rcswap!(j, int(scale[j]), X) end
    end
    convert(Matrix{T}, X)
end

## Swap rows j and jp and columns j and jp in X
function rcswap!{T<:Number}(j::Int, jp::Int, X::StridedMatrix{T})
    for k in 1:size(X, 2)
        tmp     = X[k,j]
        X[k,j]  = X[k,jp]
        X[k,jp] = tmp
        tmp     = X[j,k]
        X[j,k]  = X[jp,k]
        X[jp,k] = tmp
    end
end

# Matrix exponential
myexpm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = myexpm!(copy(A))
myexpm{T<:Integer}(A::StridedMatrix{T}) = myexpm!(float(A))

JeffBezanson commented 11 years ago

Any reason not to commit this?

ViralBShah commented 11 years ago

I am a bit hesitant because it makes the code a bit too ugly! Just as vectorizing unnecessarily makes Matlab code look bad, this is a case where devectorization makes Julia code look bad.

Also, since I am "uglifying" the code anyway, I will throw in a few calls to Blas.gemm for performance as well. I'll mess around with it a bit more and then commit.
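
For illustration only, here is a rough sketch of what calling gemm directly could look like for the repeated-squaring step. It is written for current Julia (1.x), where the routine is `LinearAlgebra.BLAS.gemm!`, rather than the 2012-era `Blas.gemm` named above, and the helper name `repeated_square!` is hypothetical, not something from the Julia code base. The point is that one scratch buffer is allocated once, instead of one new matrix per `X *= X`.

```julia
using LinearAlgebra

# Hypothetical helper: overwrite X with X^(2^si) using one reusable buffer.
function repeated_square!(X::Matrix{T}, si::Integer) where {T<:LinearAlgebra.BlasFloat}
    tmp = similar(X)                      # single scratch matrix, allocated once
    for _ in 1:si
        # tmp = 1 * X * X + 0 * tmp; the output must not alias the inputs
        BLAS.gemm!('N', 'N', one(T), X, X, zero(T), tmp)
        copyto!(X, tmp)                   # move the result back into X
    end
    return X
end

X = [1.0 1.0; 0.0 1.0]
Y = repeated_square!(copy(X), 3)          # three squarings give X^8
println(Y)
```

Whether this is actually faster than `X *= X` depends on the matrix size and the BLAS build, so it would need to be benchmarked rather than assumed.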

ViralBShah commented 11 years ago

I am reopening this because our expm is still much slower than Matlab's. Our code is pure Julia, and Matlab's code is pure Matlab. The test scripts are in #1547 (running only 10 iterations instead of 100). For the larger sizes (N >= 16), Julia's expm is roughly 3x-15x slower than Matlab's in the timings below.

julia> load("test_expm.jl")
N = 1 elapsed time: 0.3190290927886963 seconds
N = 2 elapsed time: 0.16882896423339844 seconds
N = 1 elapsed time: 6.9141387939453125e-6 seconds
N = 2 elapsed time: 0.0001819133758544922 seconds
N = 3 elapsed time: 0.006791114807128906 seconds
N = 4 elapsed time: 0.00044608116149902344 seconds
N = 5 elapsed time: 0.0006268024444580078 seconds
N = 6 elapsed time: 0.0013430118560791016 seconds
N = 7 elapsed time: 0.0016360282897949219 seconds
N = 8 elapsed time: 0.024058818817138672 seconds
N = 16 elapsed time: 0.007166862487792969 seconds
N = 32 elapsed time: 0.02408003807067871 seconds
N = 64 elapsed time: 0.12331485748291016 seconds
N = 128 elapsed time: 0.4610929489135742 seconds
N = 160 elapsed time: 0.7158498764038086 seconds
N = 192 elapsed time: 1.0662789344787598 seconds
N = 256 elapsed time: 1.8436567783355713 seconds
N = 384 elapsed time: 4.277431011199951 seconds
N = 512 elapsed time: 7.933350086212158 seconds
>> test_expm
N = 1 elapsed time: 0.047057seconds
N = 2 elapsed time: 0.076391seconds
N = 1 elapsed time: 0.0015028seconds
N = 2 elapsed time: 0.0014384seconds
N = 3 elapsed time: 0.049408seconds
N = 4 elapsed time: 0.0015149seconds
N = 5 elapsed time: 0.0014085seconds
N = 6 elapsed time: 0.0013561seconds
N = 7 elapsed time: 0.0014252seconds
N = 8 elapsed time: 0.0014336seconds
N = 16 elapsed time: 0.0017586seconds
N = 32 elapsed time: 0.0032834seconds
N = 64 elapsed time: 0.0082592seconds
N = 128 elapsed time: 0.052828seconds
N = 160 elapsed time: 0.08219seconds
N = 192 elapsed time: 0.12715seconds
N = 256 elapsed time: 0.28671seconds
N = 384 elapsed time: 0.94108seconds
N = 512 elapsed time: 2.1999seconds
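
To make the comparison easier to read, the slowdown per size can be computed directly from the two listings above. This is a small sketch in current Julia (1.x) syntax; the numbers are copied from the output in this comment for N >= 16, and the variable names are only illustrative.

```julia
# Timings in seconds, copied from the Julia and Matlab runs above (N >= 16).
sizes    = [16, 32, 64, 128, 160, 192, 256, 384, 512]
julia_t  = [0.007166862487792969, 0.02408003807067871, 0.12331485748291016,
            0.4610929489135742, 0.7158498764038086, 1.0662789344787598,
            1.8436567783355713, 4.277431011199951, 7.933350086212158]
matlab_t = [0.0017586, 0.0032834, 0.0082592, 0.052828, 0.08219,
            0.12715, 0.28671, 0.94108, 2.1999]

# Element-wise ratio: how many times slower the Julia run was at each size.
ratios = julia_t ./ matlab_t

for (n, r) in zip(sizes, ratios)
    println("N = ", n, ": Julia/Matlab = ", round(r, digits=1), "x")
end
```

These are single runs of 10 iterations each, so the ratios are only indicative.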
ViralBShah commented 11 years ago

@alanedelman Could you take a look at Julia's expm code and see if we can improve its performance?

The relevant code is here: https://github.com/JuliaLang/julia/blob/master/base/linalg_dense.jl#L249

alanedelman commented 11 years ago

Starting to investigate. Unfortunately, on my Windows Vista machine the following crashes Julia:

A=rand(35,35); expm(A);

ViralBShah commented 11 years ago

Cc: @vtjnash @loladiro

ViralBShah commented 11 years ago

@alanedelman Can you try it out on julia.mit.edu?

vtjnash commented 11 years ago

Odd, Alan's example works for me on the latest release (bad26f4bc0) on Windows 7.

andreasnoack commented 11 years ago

I just managed to get Julia running on our old Windows Server 2003 machine, and I can reproduce Alan's error. I also get the following output:

0x085D2EE3 (0x0062F570 0x00000000 0x004FC778 0x10BD01C0), LAPACKE_csyr_work() + 0xC579E3 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x00562BF8 0x10BD01C0), LAPACKE_csyr_work() + 0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x005C9078 0x10BD01C0), LAPACKE_csyr_work() + 0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x0062F4F8 0x10BD01C0), LAPACKE_csyr_work() + 0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x00000000 0x10BD01C0), LAPACKE_csyr_work() + 0x1547C6 bytes(s)
0x07346996 (0x0062F680 0x02DE8BA3 0x00000023 0x03FDD2D8), dgetrf_() + 0x146 byte
0x03FDD308 (0x00000000 0x00000000 0x00000000 0x00000000) <unknown module>
ViralBShah commented 11 years ago

Should this be a separate issue filed for Windows?

ViralBShah commented 11 years ago

I have a faster expm that I am just committing.