Closed ViralBShah closed 11 years ago
This myexpm implementation tries to reduce the number of temporaries through devectorization. It is about 30% faster.
## Matrix exponential of the square matrix A, computed from a Padé approximant.
##
## A is balanced in place by LAPACK.gebal! (so the argument is modified), the
## exponential of the balanced matrix is formed as X = (V - U) \ (V + U), and
## the balancing is then undone on X. When the 1-norm of the balanced matrix is
## above 2.1 the 14-coefficient approximant CC is used, preceded by a division
## by a power of two that is reversed afterwards by repeated squaring.
##
## Raises an error if A is not square. For matrices smaller than 2x2 the
## result of exp(A) is returned directly. Otherwise returns a Matrix{T}.
function myexpm!{T}(A::StridedMatrix{T})
    m, n = size(A)
    if m != n error("myexpm!: Matrix A must be square") end
    if m < 2 return exp(A) end
    ilo, ihi, scale = LAPACK.gebal!('B', A)    # modifies A
    nA = norm(A, 1)
    I = convert(Array{T,2}, eye(n))
    ## For sufficiently small nA, use lower order Padé-Approximations
    if (nA <= 2.1)
        if nA > 0.95
            C = [17643225600.,8821612800.,2075673600.,302702400.,
                 30270240., 2162160., 110880., 3960.,
                 90., 1.]
        elseif nA > 0.25
            C = [17297280.,8648640.,1995840.,277200.,
                 25200., 1512., 56., 1.]
        elseif nA > 0.015
            C = [30240.,15120.,3360.,
                 420., 30., 1.]
        else
            C = [120.,60.,12.,1.]
        end
        A2 = A * A
        P = copy(I)
        U = C[2] * P
        V = C[1] * P
        # P runs through the powers of A2; V collects the terms with
        # coefficients C[1], C[3], ... and U those with C[2], C[4], ...
        for k in 1:(div(size(C, 1), 2) - 1)
            k2 = 2 * k
            P *= A2
            # Devectorized form of
            #   U += C[k2 + 2] * P
            #   V += C[k2 + 1] * P
            # which avoids allocating two temporaries per iteration.
            Ck21 = C[k2 + 1]
            Ck22 = C[k2 + 2]
            for i=1:length(P)
                U[i] += Ck22 * P[i]
                V[i] += Ck21 * P[i]
            end
        end
        U = A * U
        X = (V - U)\(V + U)
    else
        s = log2(nA/5.4)    # power of 2 later reversed by squaring
        if s > 0
            si = iceil(s)
            A /= 2^si
        end
        CC = [64764752532480000.,32382376266240000.,7771770303897600.,
              1187353796428800., 129060195264000., 10559470521600.,
              670442572800., 33522128640., 1323241920.,
              40840800., 960960., 16380.,
              182., 1.]
        A2 = A * A
        A4 = A2 * A2
        A6 = A2 * A4
        # Devectorized form of
        #   U = A * (A6 * (CC[14]*A6 + CC[12]*A4 + CC[10]*A2) +
        #            CC[8]*A6 + CC[6]*A4 + CC[4]*A2 + CC[2]*I)
        #   V = A6 * (CC[13]*A6 + CC[11]*A4 + CC[9]*A2) +
        #       CC[7]*A6 + CC[5]*A4 + CC[3]*A2 + CC[1]*I
        # P1..P4 hold the four linear combinations, filled in a single pass.
        P1 = zeros(T, n, n)
        P2 = zeros(T, n, n)
        P3 = zeros(T, n, n)
        P4 = zeros(T, n, n)
        CC14 = CC[14]; CC12 = CC[12]; CC10 = CC[10]
        CC8 = CC[8]; CC6 = CC[6]; CC4 = CC[4]; CC2 = CC[2];
        CC13 = CC[13]; CC11 = CC[11]; CC9 = CC[9]
        CC7 = CC[7]; CC5 = CC[5]; CC3 = CC[3]; CC1 = CC[1]
        for i=1:length(I)
            P1[i] += CC14*A6[i] + CC12*A4[i] + CC10*A2[i]
            P2[i] += CC8*A6[i] + CC6*A4[i] + CC4*A2[i] + CC2*I[i]
            P3[i] += CC13*A6[i] + CC11*A4[i] + CC9*A2[i]
            P4[i] += CC7*A6[i] + CC5*A4[i] + CC3*A2[i] + CC1*I[i]
        end
        U = A * (A6*P1 + P2)
        V = A6*P3 + P4
        X = (V-U)\(V+U)
        if s > 0    # squaring to reverse dividing by power of 2
            for t in 1:si X *= X end
        end
    end
    # Undo the balancing.
    # Scaling first: with D the diagonal matrix whose entries ilo:ihi are
    # scale[ilo:ihi] and whose other entries are 1, replace X by D*X*inv(D).
    # Whole rows and columns are rescaled, not just the ilo:ihi block: an
    # entry X[i,j] with only one of i, j inside ilo:ihi still picks up one
    # factor. Entries of scale outside ilo:ihi are permutation indices and
    # must not be used as factors, hence the outer loop stays on ilo:ihi.
    for j = ilo:ihi
        scj = scale[j]
        if scj != 1.    # nothing to do for an unscaled row/column
            for i = 1:n
                X[j,i] *= scj
            end
            for i = 1:n
                X[i,j] /= scj
            end
        end
    end
    if ilo > 1    # apply lower permutations in reverse order
        # start:step:stop -- count down from ilo-1 to 1
        for j in (ilo-1):-1:1 rcswap!(j, int(scale[j]), X) end
    end
    if ihi < n    # apply upper permutations in forward order
        for j in (ihi+1):n rcswap!(j, int(scale[j]), X) end
    end
    convert(Matrix{T}, X)
end
## Swap rows j and jp and columns j and jp in X (in place).
##
## The two exchanges are done in separate passes: columns j and jp are swapped
## in every row first, then rows j and jp are swapped in every column. Doing
## both inside a single loop is wrong, because the four entries X[j,j],
## X[j,jp], X[jp,j] and X[jp,jp] take part in both exchanges and would be hit
## by the row swap before the column swap has finished.
## The element type is left unconstrained: swapping needs no arithmetic.
function rcswap!(j::Int, jp::Int, X::StridedMatrix)
    for k in 1:size(X, 1)    # exchange columns j and jp, one row at a time
        tmp = X[k,j]
        X[k,j] = X[k,jp]
        X[k,jp] = tmp
    end
    for k in 1:size(X, 2)    # exchange rows j and jp, one column at a time
        tmp = X[j,k]
        X[j,k] = X[jp,k]
        X[jp,k] = tmp
    end
end
# Matrix exponential (entry points that forward to myexpm!)
# Floating-point element types: myexpm! receives a copy of A.
function myexpm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T})
    return myexpm!(copy(A))
end
# Integer element types: myexpm! receives float(A).
function myexpm{T<:Integer}(A::StridedMatrix{T})
    return myexpm!(float(A))
end
Any reason not to commit this?
I am a bit hesitant because it makes the code a bit too ugly! Just like vectorizing unnecessarily makes matlab code look bad, this is a case where devectorization makes julia code look bad.
Also, since I am "uglifying" the code, I will also throw in a few calls to BLAS.gemm for performance. I'll mess around with it a bit more and then commit.
I am reopening this, because our expm is still orders of magnitude slower than Matlab. Our code is pure julia, and Matlab's code is pure matlab. The test scripts are in #1547 (running only 10 iterations instead of 100). Julia's expm is 3x-10x slower than Matlab's.
julia> load("test_expm.jl")
N = 1 elapsed time: 0.3190290927886963 seconds
N = 2 elapsed time: 0.16882896423339844 seconds
N = 1 elapsed time: 6.9141387939453125e-6 seconds
N = 2 elapsed time: 0.0001819133758544922 seconds
N = 3 elapsed time: 0.006791114807128906 seconds
N = 4 elapsed time: 0.00044608116149902344 seconds
N = 5 elapsed time: 0.0006268024444580078 seconds
N = 6 elapsed time: 0.0013430118560791016 seconds
N = 7 elapsed time: 0.0016360282897949219 seconds
N = 8 elapsed time: 0.024058818817138672 seconds
N = 16 elapsed time: 0.007166862487792969 seconds
N = 32 elapsed time: 0.02408003807067871 seconds
N = 64 elapsed time: 0.12331485748291016 seconds
N = 128 elapsed time: 0.4610929489135742 seconds
N = 160 elapsed time: 0.7158498764038086 seconds
N = 192 elapsed time: 1.0662789344787598 seconds
N = 256 elapsed time: 1.8436567783355713 seconds
N = 384 elapsed time: 4.277431011199951 seconds
N = 512 elapsed time: 7.933350086212158 seconds
>> test_expm
N = 1 elapsed time: 0.047057seconds
N = 2 elapsed time: 0.076391seconds
N = 1 elapsed time: 0.0015028seconds
N = 2 elapsed time: 0.0014384seconds
N = 3 elapsed time: 0.049408seconds
N = 4 elapsed time: 0.0015149seconds
N = 5 elapsed time: 0.0014085seconds
N = 6 elapsed time: 0.0013561seconds
N = 7 elapsed time: 0.0014252seconds
N = 8 elapsed time: 0.0014336seconds
N = 16 elapsed time: 0.0017586seconds
N = 32 elapsed time: 0.0032834seconds
N = 64 elapsed time: 0.0082592seconds
N = 128 elapsed time: 0.052828seconds
N = 160 elapsed time: 0.08219seconds
N = 192 elapsed time: 0.12715seconds
N = 256 elapsed time: 0.28671seconds
N = 384 elapsed time: 0.94108seconds
N = 512 elapsed time: 2.1999seconds
@alanedelman Could you take a look at julia's expm code and see if we can improve its performance?
The relevant code is here: https://github.com/JuliaLang/julia/blob/master/base/linalg_dense.jl#L249
Starting to investigate -- unfortunately on my windows vista machine the following crashes julia
A=rand(35,35); expm(A);
Cc: @vtjnash @loladiro
@alanedelman Can you try it out on julia.mit.edu?
odd, alan's example works for me on the latest release bad26f4bc0 on Windows 7
I just managed to get Julia running on our old Windows Server 2003 machine, and I can reproduce Alan's error. I also get the following output:
0x085D2EE3 (0x0062F570 0x00000000 0x004FC778 0x10BD01C0), LAPACKE_csyr_work() +
0xC579E3 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x00562BF8 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x005C9078 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x0062F4F8 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07ACFCC6 (0x0062F570 0x00000000 0x00000000 0x10BD01C0), LAPACKE_csyr_work() +
0x1547C6 bytes(s)
0x07346996 (0x0062F680 0x02DE8BA3 0x00000023 0x03FDD2D8), dgetrf_() + 0x146 byte
s(s)
0x03FDD308 (0x00000000 0x00000000 0x00000000 0x00000000) <unknown module>
Should this be a separate issue filed for Windows?
I have a faster expm that I am just committing.
Improve expm performance by directly calling BLAS/LAPACK routines in expm and using the fewest number of temporaries.
https://groups.google.com/d/msg/julia-dev/QdaRhy3DoWE/KtIWdsNdrKcJ