JuliaLang / julia

The Julia Programming Language
https://julialang.org/
MIT License
45.54k stars 5.47k forks source link

Improve julia's expm performance #1543

Closed ViralBShah closed 11 years ago

ViralBShah commented 11 years ago

Improve expm performance by directly calling BLAS/LAPACK routines in expm and using the fewest number of temporaries.

https://groups.google.com/d/msg/julia-dev/QdaRhy3DoWE/KtIWdsNdrKcJ

ViralBShah commented 11 years ago

This myexpm implementation tries to reduce the number of temporaries through devectorization. It is about 30% faster.

## Matrix exponential of the square matrix A by Padé approximation.
##
## A is overwritten: it is balanced in place by LAPACK.gebal!('B', A) and the
## result is computed from the balanced matrix, after which the balancing
## (diagonal scaling and row/column interchanges) is undone on the result.
## For a 1-norm of the balanced matrix <= 2.1 a lower-order Padé approximant
## is used; otherwise the matrix is divided by a power of 2, the 13th-order
## coefficient table CC is used, and the result is squared back.
##
## Raises an error if A is not square. Returns a Matrix{T}.
function myexpm!{T}(A::StridedMatrix{T})
    m, n = size(A)
    if m != n error("myexpm!: Matrix A must be square") end
    if m < 2 return exp(A) end
    ilo, ihi, scale = LAPACK.gebal!('B', A)    # modifies A
    nA   = norm(A, 1)
    I    = convert(Array{T,2}, eye(n))
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = [17643225600.,8821612800.,2075673600.,302702400.,
                    30270240.,   2162160.,    110880.,     3960.,
                          90.,         1.]
        elseif nA > 0.25
            C = [17297280.,8648640.,1995840.,277200.,
                    25200.,   1512.,     56.,     1.]
        elseif nA > 0.015
            C = [30240.,15120.,3360.,
                   420.,   30.,   1.]
        else
            C = [120.,60.,12.,1.]
        end
        A2 = A * A
        P  = copy(I)
        U  = C[2] * P
        V  = C[1] * P
        for k in 1:(div(size(C, 1), 2) - 1)
            k2 = 2 * k
            P *= A2
            # Devectorized form of
            #     U += C[k2 + 2] * P
            #     V += C[k2 + 1] * P
            # to avoid allocating two temporaries per iteration.
            Ck21 = C[k2 + 1]
            Ck22 = C[k2 + 2]
            for i=1:length(P)
                U[i] += Ck22 * P[i]
                V[i] += Ck21 * P[i]
            end
        end
        U  = A * U
        X  = (V - U)\(V + U)
    else
        s  = log2(nA/5.4)               # power of 2 later reversed by squaring
        if s > 0
            si = iceil(s)
            A /= 2^si
        end
        CC = [64764752532480000.,32382376266240000.,7771770303897600.,
               1187353796428800.,  129060195264000.,  10559470521600.,
                   670442572800.,      33522128640.,      1323241920.,
                       40840800.,           960960.,           16380.,
                            182.,                1.]
        A2 = A * A
        A4 = A2 * A2
        A6 = A2 * A4
        # Devectorized form of
        #     U = A * (A6 * (CC[14]*A6 + CC[12]*A4 + CC[10]*A2) +
        #              CC[8]*A6 + CC[6]*A4 + CC[4]*A2 + CC[2]*I)
        #     V = A6 * (CC[13]*A6 + CC[11]*A4 + CC[9]*A2) +
        #              CC[7]*A6 + CC[5]*A4 + CC[3]*A2 + CC[1]*I
        # P1..P4 hold the four inner polynomials, filled in a single pass.
        P1 = zeros(T, n, n)
        P2 = zeros(T, n, n)
        P3 = zeros(T, n, n)
        P4 = zeros(T, n, n)
        CC14 = CC[14]; CC12 = CC[12]; CC10 = CC[10]
        CC8 = CC[8];   CC6 = CC[6];   CC4 = CC[4];   CC2 = CC[2]
        CC13 = CC[13]; CC11 = CC[11]; CC9 = CC[9]
        CC7 = CC[7];   CC5 = CC[5];   CC3 = CC[3];   CC1 = CC[1]
        for i=1:length(I)
            P1[i] += CC14*A6[i] + CC12*A4[i] + CC10*A2[i]
            P2[i] += CC8*A6[i] + CC6*A4[i] + CC4*A2[i] + CC2*I[i]
            P3[i] += CC13*A6[i] + CC11*A4[i] + CC9*A2[i]
            P4[i] += CC7*A6[i] + CC5*A4[i] + CC3*A2[i] + CC1*I[i]
        end
        U = A * (A6*P1 + P2)
        V = A6*P3 + P4

        X  = (V-U)\(V+U)

        if s > 0            # squaring to reverse dividing by power of 2
            for t in 1:si X *= X end
        end
    end

    # Undo the balancing.
    # Scaling: with D = diag(d), d[j] = scale[j] for j in ilo:ihi and 1
    # elsewhere, the balanced matrix is inv(D)*A*D, so the result must be
    # mapped back as D*X*inv(D). That rescales the *entire* rows and columns
    # ilo:ihi, not just the ilo:ihi x ilo:ihi sub-block. Entries of scale
    # outside ilo:ihi are interchange indices and must not be used as factors.
    for j = ilo:ihi
        scj = scale[j]
        if scj != 1.
            for i = 1:n
                X[j,i] *= scj
            end
            for i = 1:n
                X[i,j] /= scj
            end
        end
    end
    if ilo > 1       # apply lower permutations in reverse order
        # Range is start:step:stop; the step must be -1 to count down to 1.
        for j in (ilo-1):-1:1 rcswap!(j, int(scale[j]), X) end
    end
    if ihi < n       # apply upper permutations in forward order
        for j in (ihi+1):n    rcswap!(j, int(scale[j]), X) end
    end
    convert(Matrix{T}, X)
end

## Swap columns j and jp, then rows j and jp, of X in place (i.e. apply the
## same interchange symmetrically on both sides of X).
##
## The two swaps are done in separate passes: interleaving them in a single
## loop corrupts the four entries where rows/columns j and jp intersect.
## The element type is not used, so X accepts any StridedMatrix.
function rcswap!(j::Int, jp::Int, X::StridedMatrix)
    # Pass 1: exchange columns j and jp in every row.
    for k in 1:size(X, 1)
        tmp     = X[k,j]
        X[k,j]  = X[k,jp]
        X[k,jp] = tmp
    end
    # Pass 2: exchange rows j and jp in every column.
    for k in 1:size(X, 2)
        tmp     = X[j,k]
        X[j,k]  = X[jp,k]
        X[jp,k] = tmp
    end
end

# Matrix exponential: non-mutating entry points that hand a fresh array to
# myexpm!, so the caller's matrix is left untouched.
function myexpm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T})
    work = copy(A)
    return myexpm!(work)
end
function myexpm{T<:Integer}(A::StridedMatrix{T})
    # Integer input is converted to floating point first.
    work = float(A)
    return myexpm!(work)
end
JeffBezanson commented 11 years ago

Any reason not to commit this?

ViralBShah commented 11 years ago

I am a bit hesitant because it makes the code a bit too ugly! Just like vectorizing unnecessarily makes matlab code look bad, this is a case where devectorization makes julia code look bad.

Also, since I am "uglyfying" the code, I will throw in a few calls to Blas.gemm also for performance. I'll mess around with it a bit more and then commit.

ViralBShah commented 11 years ago

I am reopening this, because our expm is still orders of magnitude slower than Matlab. Our code is pure julia, and Matlab's code is pure matlab. The test scripts are in #1547 (running only 10 iterations instead of 100). Julia's expm is 3x-10x slower than Matlab's.

julia> load("test_expm.jl")
N = 1 elapsed time: 0.3190290927886963 seconds
N = 2 elapsed time: 0.16882896423339844 seconds
N = 1 elapsed time: 6.9141387939453125e-6 seconds
N = 2 elapsed time: 0.0001819133758544922 seconds
N = 3 elapsed time: 0.006791114807128906 seconds
N = 4 elapsed time: 0.00044608116149902344 seconds
N = 5 elapsed time: 0.0006268024444580078 seconds
N = 6 elapsed time: 0.0013430118560791016 seconds
N = 7 elapsed time: 0.0016360282897949219 seconds
N = 8 elapsed time: 0.024058818817138672 seconds
N = 16 elapsed time: 0.007166862487792969 seconds
N = 32 elapsed time: 0.02408003807067871 seconds
N = 64 elapsed time: 0.12331485748291016 seconds
N = 128 elapsed time: 0.4610929489135742 seconds
N = 160 elapsed time: 0.7158498764038086 seconds
N = 192 elapsed time: 1.0662789344787598 seconds
N = 256 elapsed time: 1.8436567783355713 seconds
N = 384 elapsed time: 4.277431011199951 seconds
N = 512 elapsed time: 7.933350086212158 seconds
>> test_expm
N = 1 elapsed time: 0.047057seconds
N = 2 elapsed time: 0.076391seconds
N = 1 elapsed time: 0.0015028seconds
N = 2 elapsed time: 0.0014384seconds
N = 3 elapsed time: 0.049408seconds
N = 4 elapsed time: 0.0015149seconds
N = 5 elapsed time: 0.0014085seconds
N = 6 elapsed time: 0.0013561seconds
N = 7 elapsed time: 0.0014252seconds
N = 8 elapsed time: 0.0014336seconds
N = 16 elapsed time: 0.0017586seconds
N = 32 elapsed time: 0.0032834seconds
N = 64 elapsed time: 0.0082592seconds
N = 128 elapsed time: 0.052828seconds
N = 160 elapsed time: 0.08219seconds
N = 192 elapsed time: 0.12715seconds
N = 256 elapsed time: 0.28671seconds
N = 384 elapsed time: 0.94108seconds
N = 512 elapsed time: 2.1999seconds
ViralBShah commented 11 years ago

@alanedelman Could you take a look at julia expm code and see if we can improve its performance?

The relevant code is here: https://github.com/JuliaLang/julia/blob/master/base/linalg_dense.jl#L249

alanedelman commented 11 years ago

Starting to investigate -- unfortunately on my windows vista machine the following crashes julia

A=rand(35,35); expm(A);

ViralBShah commented 11 years ago

Cc: @vtjnash @loladiro

ViralBShah commented 11 years ago

@alanedelman Can you try it out on julia.mit.edu?

vtjnash commented 11 years ago

odd, alan's example works for me on the latest release bad26f4bc0 on Windows 7

andreasnoack commented 11 years ago

I just managed to get Julia running on our old Windows Server 2003 and I can reproduce Alan's error, and I also get the output

0x085D2EE3 (0x0062F570 0x00000000 0x004FC778 0x10BD01C0), LAPACKE_csyr_work() +
0xC579E3 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x00562BF8 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x005C9078 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x0062F4F8 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x00000000 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07346996 (0x0062F680 0x02DE8BA3 0x00000023 0x03FDD2D8), dgetrf_() + 0x146 byte
s(s)
0x03FDD308 (0x00000000 0x00000000 0x00000000 0x00000000) <unknown module>
ViralBShah commented 11 years ago

Should this be a separate issue filed for Windows?

ViralBShah commented 11 years ago

I have a faster expm that I am just committing.