kailaix / ADCME.jl

Automatic Differentiation Library for Computational and Mathematical Engineering
https://kailaix.github.io/ADCME.jl/latest/
MIT License
286 stars 57 forks source link

Efficient algorithm for solving tridiagonal matrix linear equations #56

Open GiggleLiu opened 4 years ago

GiggleLiu commented 4 years ago

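The linear solve in the first README example can be made faster: the system matrix is tridiagonal, so it can be solved in O(n) with the tridiagonal matrix (Thomas) algorithm instead of a general dense solver.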

"""
### References
* https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm
"""
function trisolve!(A::AbstractVector, B::AbstractVector, C::AbstractVector, 
           D::AbstractVector, X::AbstractVector)
    N = length(X)
    B = copy(B)
    D = copy(D)
    @inbounds for i = 2:N
        W = A[i-1] / B[i - 1]
        B[i] = B[i] - W * C[i - 1]
        D[i] = D[i] - W * D[i - 1]
    end
    @inbounds X[N] = D[N] / B[N]
    @inbounds for i = N-1:-1:1
        X[i] = (D[i] - C[i] * X[i + 1]) / B[i]
    end
    return X
end

When `A`, `B`, and `C` are uniform (each diagonal holds a single constant value), we have

"""
    trisolve!(a, b, c, D, X)

Solve a tridiagonal linear system whose diagonals are constant, using the tridiagonal
matrix (Thomas) algorithm. The solution is written into `X`, which is also returned.

Row `i` of the system reads `a*X[i-1] + b*X[i] + c*X[i+1] = D[i]`.

# Arguments
- `a`, `b`, `c`: the constant sub-, main- and super-diagonal values (all of type `T`).
- `D`: right-hand side (≥ `N` entries).
- `X`: output vector; its length `N` defines the size of the system.

Only `X` is modified: the elimination works on a private copy of `D`.
No pivoting is performed. All vectors are indexed from 1.

# Throws
- `DimensionMismatch` if `D` has fewer entries than `X`.
"""
function trisolve!(a::T, b::T, c::T, D::AbstractVector, X::AbstractVector) where T
    N = length(X)
    # An empty system has an empty solution; without this guard the code below
    # would touch index 0 under `@inbounds`.
    N == 0 && return X
    # Validate once up front so that the `@inbounds` accesses below are actually safe.
    length(D) >= N || throw(DimensionMismatch(
        "system size is $N (length of X) but the right-hand side has length $(length(D))"))
    # Work on a copy so the caller's right-hand side is left untouched.
    D = copy(D)
    # B holds the modified main diagonal produced by the forward sweep.
    B = zeros(T, N)
    @inbounds B[1] = b
    @inbounds for i = 2:N
        w = a / B[i-1]
        B[i] = b - w * c
        D[i] = D[i] - w * D[i - 1]
    end
    # Back substitution, starting from the last unknown.
    @inbounds X[N] = D[N] / B[N]
    @inbounds for i = N-1:-1:1
        X[i] = (D[i] - c * X[i + 1]) / B[i]
    end
    return X
end

"""
    myloss(b; n=101)

Return the squared misfit `(u(0.5) - 1)^2` for the coefficient `b`.

`u` is the solution of the tridiagonal system with constant diagonals
`(-b/h^2, 2b/h^2 + 1, -b/h^2)` and right-hand side `4*(2 + x - x^2)`, posed on the
`n - 2` interior nodes of a uniform grid of `n` nodes on `[0, 1]` with spacing
`h = 1/(n-1)`.

# Arguments
- `b`: scalar coefficient; its type `T` is used for the right-hand side and solution,
  so a dual number can be passed to obtain a derivative.

# Keywords
- `n`: total number of grid nodes including both end points (default `101`). Use an odd
  `n` so that `x = 0.5` is a grid node.

# Throws
- `ArgumentError` if `n < 3`, i.e. if the grid has no interior node.
"""
function myloss(b::T; n=101) where T
    n >= 3 || throw(ArgumentError(
        "n must be at least 3 so that the grid has an interior node, got $n"))
    h = 1/(n-1)
    # Interior nodes only: x[k] == k*h for k = 1, ..., n-2.
    x = LinRange(0,1,n)[2:end-1]
    f = @. T(4*(2 + x - x^2))

    u = trisolve!(-b/h^2, 2b/h^2+1, -b/h^2, f, zeros(T, n-2))
    # u[k] approximates the solution at x = k*h, so x = 0.5 is k = (n-1)/2.
    # (div(n+1,2) would select x = 0.5 + h and is out of bounds for n == 3.)
    ue = u[div(n-1,2)] # extract value at x=0.5

    return (ue-1.0)^2
end

# Evaluate the loss at b = 10.0 on the default grid (n = 101).
myloss(10.0)

# With a single scalar input, forward-mode AD (ForwardDiff) is the most efficient way to
# differentiate this program: evaluate it once on a dual number seeded with b = 10.0 and
# unit derivative 1.0.
using ForwardDiff
myloss(ForwardDiff.Dual(10.0, 1.0))

The NiLang version is

using NiLang, NiLang.AD

# Reversible (NiLang) counterpart of the uniform-diagonal `trisolve!`: Thomas algorithm
# for a tridiagonal system with constant sub-/main-/super-diagonal values a, b, c.
#
#   D!  right-hand side; updated in place by the forward sweep (no copy is made here)
#   X!  receives the solution; values are accumulated with `+=`
#   B!  receives the modified main diagonal; values are accumulated with `+=`
#
# Because X! and B! are only accumulated into, they presumably have to be zero on entry;
# the caller in this file passes zero-filled vectors.
@i function i_trisolve!(a::T, b::T, c::T, D!::AbstractVector, X!::AbstractVector, B!::AbstractVector) where T
    @invcheckoff @inbounds begin
    B![1] += b
    # Forward sweep: eliminate the sub-diagonal row by row.
    for i = 2:length(X!)
        # Compute the elimination factor w in a temporary...
        @routine begin
            w ← zero(T)
            w += a / B![i-1]
        end
        B![i] += b
        B![i] -= w * c
        D![i] -= w * D![i - 1]
        # ...and uncompute it again once it has been used.
        ~@routine
    end
    # Back substitution, starting from the last unknown.
    X![end] += D![end] / B![end]
    for i = length(X!)-1:-1:1
        # anc = D![i] - c * X![i+1], held in a temporary that is uncomputed below.
        @routine begin
            anc ← zero(T)
            anc += D![i]
            anc -= c * X![i + 1]
        end
        X![i] += anc / B![i]
        ~@routine
    end
    end
end

# Reversible (NiLang) counterpart of `myloss`: accumulates the squared misfit
# (u(0.5) - 1)^2 into `loss`.
#
#   loss      accumulator for the result (updated with `+=`)
#   f!        buffer for the right-hand side 4*(2 + x - x^2) at the interior nodes;
#             its length fixes the grid: n = length(f!) + 2 nodes, spacing h = 1/(n-1)
#   u!        buffer that receives the solution at the interior nodes
#   b_cache!  work buffer handed to `i_trisolve!` (modified main diagonal)
#   b         scalar coefficient
#
# All buffers are only accumulated into, so they presumably have to be zero on entry;
# the calls in this file pass zero-filled vectors.
@i function i_myloss!(loss::T, f!::AbstractVector{T}, u!::AbstractVector{T},
              b_cache!::AbstractVector{T}, b::T) where T
    @invcheckoff @inbounds begin
    n ← length(f!) + 2
    h ← zero(T)
    h += 1 / (n-1)
    # Right-hand side at the interior nodes xi = i*h, i = 1, ..., n-2.
    for i=1:length(f!)
        @routine begin
            @zeros T xi anc
            xi += i * h
            anc += 2 + xi
            anc -= xi ^ 2
        end
        f![i] += 4 * anc
        ~@routine
    end

    # Constant diagonals of the system:
    #   factor_a = factor_c = -b/h^2,  factor_b = 2b/h^2 + 1
    @routine begin
        @zeros T h2 factor_a factor_b factor_c
        h2 += h^2
        factor_a -= b / h2
        factor_c += factor_a
        factor_b -= 2 * factor_a
        factor_b += 1
    end
    i_trisolve!(factor_a, factor_b, factor_c, f!, u!, b_cache!)
    ~@routine
    @routine begin
        ue ← zero(T)
        # u![k] lives at x = k*h (see the loop above), so x = 0.5 is k = (n-1)/2.
        # (div(n+1,2) would select x = 0.5 + h instead.)
        ue += u![div(n-1,2)] # extract value at x=0.5
        ue -= 1
    end
    loss += ue^2
    ~@routine
    # Uncompute h so that it returns to zero.
    h -= 1 / (n-1)
    end
end

# Number of grid nodes, including the two end points; all buffers hold the n-2 interior nodes.
n = 101
# Forward evaluation. Arguments are (loss, f!, u!, b_cache!, b), with the loss accumulator
# and all buffers starting at zero and b = 10.0.
i_myloss!(0.0, zeros(n-2), zeros(n-2), zeros(n-2), 10.0)
# Gradient evaluation with the same inputs.
# NOTE(review): Val(1) presumably marks the first argument (`loss`) as the quantity whose
# gradient is taken — confirm against the NiLang.AD documentation.
Grad(i_myloss!)(Val(1), 0.0, zeros(n-2), zeros(n-2), zeros(n-2), 10.0)
GiggleLiu commented 4 years ago

Benchmarks

julia> @benchmark myloss(ForwardDiff.Dual(10.0, 1.0))
BenchmarkTools.Trial: 
  memory estimate:  7.06 KiB
  allocs estimate:  4
  --------------
  minimum time:     2.605 μs (0.00% GC)
  median time:      2.790 μs (0.00% GC)
  mean time:        3.356 μs (7.09% GC)
  maximum time:     190.302 μs (96.25% GC)
  --------------
  samples:          10000
  evals/sample:     9

julia> @benchmark Grad(i_myloss!)(Val(1), 0.0, zeros($n-2), zeros($n-2), zeros($n-2), 10.0)
BenchmarkTools.Trial: 
  memory estimate:  7.92 KiB
  allocs estimate:  6
  --------------
  minimum time:     7.122 μs (0.00% GC)
  median time:      7.266 μs (0.00% GC)
  mean time:        8.183 μs (3.35% GC)
  maximum time:     339.283 μs (95.11% GC)
  --------------
  samples:          10000
  evals/sample:     6

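In the single-parameter case, forward-mode differentiation (ForwardDiff) is faster because reverse-mode differentiation traverses the program twice: once forward to compute the result and once backward to propagate the gradients.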