GiggleLiu / NiLang.jl

A differential eDSL that can run faster than light and go back to the past.
https://giggleliu.github.io/NiLang.jl/dev
Apache License 2.0
250 stars 16 forks source link

InvertibilityError in reversible matrix-vector dot product #61

Open guixinliu opened 3 years ago

guixinliu commented 3 years ago

Hi, I wanted to write my own reversible matrix-vector dot product by following the examples at

https://github.com/GiggleLiu/NiLang.jl/blob/master/src/stdlib/sparse.jl

But I got some error. The details are as follows:

I rewrote this function https://github.com/JuliaLang/julia/blob/b773bebcdb1eccaf3efee0bfe564ad552c0bcea7/stdlib/SparseArrays/src/linalg.jl#L331.

Here are my code :

# Attempted reversible port of SparseArrays' three-argument `dot(x, A, y)`
# to NiLang's @i reversible DSL.
#
# BUG (the subject of this issue): variables allocated with `←` are ancillas
# that NiLang deallocates at scope exit, and deallocation requires them to be
# back at their initial value. Here `r`, `ycol`, `temp`, and `anc1` are left
# holding nonzero values and are never uncomputed or stored, so the compiler
# raises InvertibilityError — the reported message shows `r` holding the
# (nonzero) dot-product result -40.07… at deallocation time.
@i function idot(x::Vector, A::SparseMatrixCSC, y::Vector{T}) where {T}
    @routine @invcheckoff begin
        @safe Base.require_one_based_indexing(x,y) # maybe irreversible ? (@safe escapes to ordinary Julia)
        (m, n) ← size(A)
    end
    # Dimension check, done irreversibly via @safe.
    @safe (length(x) == m && n == length(y)) || throw(DimensionMismatch())
    T1 ← promote_type(eltype(x), eltype(A), eltype(y))
    # if (iszero(m) || iszero(n) , ~)
    #     dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
    # end
    r ← zero(T1)          # result accumulator — ancilla, never deallocated (leak)
    rvals ← getrowval(A)  # row indices of A's stored entries
    nzvals ← getnzval(A)  # values of A's stored entries
    @invcheckoff @inbounds for col in 1:n
        ycol ← y[col]     # ancilla, never deallocated (leak)
        if (!iszero(ycol), ~)
            temp ← zero(T1)   # ancilla, never deallocated (leak)
            for k in nzrange(A, col)
                anc1 ← zero(x[rvals[k]])   # ancilla, never deallocated (leak)
                anc1 += adjoint(x[rvals[k]])
                temp += anc1 * nzvals[k]
            end
        end
        # NOTE(review): `temp` is read here even when the branch above was not
        # taken, in which case it is not defined for this iteration.
        r += temp * ycol
    end
    ~@routine
end

my test code is:

using BenchmarkTools
# NOTE(review): this snippet additionally needs `using NiLang, SparseArrays`
# in scope for `sprand` and the @i-generated `idot` to resolve.
a = sprand(1000, 1000, 0.01);
x = randn(1000);
y = randn(1000);

# Calling the reversible function forward; this triggers the error below.
@benchmark idot(x, a, y)

It didn't work. I got the error:

ERROR: LoadError: InvertibilityError("can not deallocate because -40.07385632952906 ≂̸ 0.0")
Stacktrace:
 [1] deanc at C:\Users\myWork\.julia\packages\NiLangCore\6yO1L\src\vars.jl:8 [inlined]
 [2] idot(::Array{Float64,1}, ::SparseMatrixCSC{Float64,Int64}, ::Array{Float64,1}) at e:\Workspace\NiLang-learn\SparseArrays_idot.jl:53
 [3] top-level scope at e:\Workspace\NiLang-learn\test\idot_test.jl:7
in expression starting at e:\Workspace\NiLang-learn\test\idot_test.jl:7

I couldn't find out why it occurred. I would be very appreciated if anyone can give some help.

GiggleLiu commented 3 years ago

Hi, I would suggest upgrading NiLang to v0.9; the latest NiLang will give you more detailed error messages. Actually, you have chosen a tough function to implement in NiLang. The problem in your code is that you didn't deallocate your ancillas, which makes the compiler throw an InvertibilityError. You have to either store them somewhere or uncompute them back to zero.

Here are two example approaches to writing this function reversibly.

No time overhead, but an nnz(y) * length(x) space overhead.

using NiLang, SparseArrays

# Preallocated workspace-plus-result for the reversible `idot` (approach 1).
# All state the reversible computation produces is *stored* here rather than
# uncomputed, trading memory for speed.
struct SpdotRes{T}
    r::T                          # accumulated dot-product result
    ds::Vector{T}                 # partial dot product per stored entry of y
    xinds::Vector{Int}            # final x-cursor offset per stored entry of y
    yinds::Vector{Int}            # final y-cursor offset per stored entry of y
    branch_keeper::Vector{UInt8}  # record of branches taken inside acc_spdot
    branch_keeper_i::Int          # number of branch records written so far
end

# Reversible sparse dot product (approach 1): no uncomputation in the hot
# path. Every intermediate (partial dots, cursor positions, branch history)
# is persisted in `res` — see `auto_allocate` for sizing — so nothing needs
# to be reversed to zero here.
@i function idot(res::SpdotRes, x::SparseVector, A::SparseArrays.AbstractSparseMatrixCSC, y::SparseVector)
    m, n ← size(A)
    @safe length(x) == m && n == length(y) || throw(DimensionMismatch())
    @invcheckoff @inbounds if !(iszero(m) || iszero(n))
        # Loop over the stored (structurally nonzero) entries of y.
        for j = 1:length(y.nzind)
            @routine begin
                # Range of A's stored entries in column y.nzind[j] (CSC layout).
                A_ptr_lo ← A.colptr[y.nzind[j]]
                A_ptr_hi ← A.colptr[y.nzind[j]+1] - 1
            end
            if A_ptr_lo <= A_ptr_hi
                # Accumulate dot(x, A[:, y.nzind[j]]) into res.ds[j]; cursors
                # and branch records land in `res` instead of being uncomputed.
                acc_spdot(res.ds[j], 1, res.xinds[j], @const(length(x.nzind)), x.nzind, x.nzval,
                                                A_ptr_lo, res.yinds[j], A_ptr_hi, A.rowval, A.nzval,
                                                res.branch_keeper, res.branch_keeper_i)
                res.r += res.ds[j] * y.nzval[j]
            end
            ~@routine   # uncompute A_ptr_lo/A_ptr_hi so they deallocate as clean
        end
    end
    m, n → size(A)   # deallocate (m, n); checks they still equal size(A)
end

# Reversible two-pointer merge over two sorted index ranges:
# xnzind[xj_first:xj_last] vs ynzind[yj_first:yj_last], accumulating the dot
# product of matching entries into `s`. The merge takes a data-dependent
# branch every iteration, so the branch tag (1/2/3) is SWAPped into
# `branch_keeper` — that record is what makes the loop invertible.
@i function acc_spdot(s::T, xj_first::Int, xj::Int, xj_last::Int, xnzind, xnzval,
                yj_first::Int, yj::Int, yj_last::Int, ynzind, ynzval, branch_keeper, branch_keeper_i) where T
    # dot product between ranges of non-zeros,
    xj += xj_first
    yj += yj_first
    # @from P while C: P must hold before the first iteration only (see Notes).
    @invcheckoff @inbounds @from xj==xj_first && yj==yj_first while xj <= xj_last && yj <= yj_last
        @routine begin
            ix ← xnzind[xj]
            iy ← ynzind[yj]
        end
        bk ← zero(UInt8)
        # Encode which of the three merge cases applies in `bk`.
        if ix == iy
            bk ⊻= 0x1
        elseif ix < iy
            bk ⊻= 0x2
        else
            bk ⊻= 0x3
        end
        ~@routine
        if bk==1
            # Indices match: accumulate the product, advance both cursors.
            s += xnzval[xj]' * ynzval[yj]
            xj += 1
            yj += 1
        elseif bk==2
            xj += 1
        else
            yj += 1
        end
        # Park the branch tag in the keeper array so `bk` itself can be
        # deallocated as zero below.
        branch_keeper_i += 1
        SWAP(branch_keeper[branch_keeper_i], bk)
        bk → zero(UInt8)
    end
end

"""
    auto_allocate(::typeof(idot), x, A, y)

Build a zero-initialized `SpdotRes` workspace sized for `idot(res, x, A, y)`:
one partial-result slot and one cursor pair per stored entry of `y`, plus a
branch-record buffer large enough for the worst case.
"""
function auto_allocate(::typeof(idot), x::SparseVector{T1}, A::SparseArrays.AbstractSparseMatrixCSC{T2},
                                                y::SparseVector{T3}) where {T1, T2, T3}
    Tres = promote_type(T1, T2, T3)
    nnzy = length(y.nzind)
    # Worst-case branch count: every row of A for every stored entry of y.
    bk_buf = zeros(UInt8, size(A, 1) * length(y.nzval))
    partials = zeros(Tres, nnzy)
    xcursors = zeros(Int, nnzy)
    ycursors = zeros(Int, nnzy)
    return SpdotRes(zero(Tres), partials, xcursors, ycursors, bk_buf, 0)
end

# Expose `res.r` as the NiLang "value" view of a SpdotRes, so `value(res)`
# reads out the accumulated dot-product result.
@fieldview NiLang.value(res::SpdotRes) = res.r

using Test
@testset "sparse dot" begin
    # Correctness: compare against the irreversible SparseArrays reference.
    vx = sprand(1000, 0.2)
    vy = sprand(1000, 0.2)
    A = sprand(1000, 1000, 0.01)
    res = SparseArrays.dot(vx, A, vy)
    res2 = idot(auto_allocate(idot, vx, A, vy), vx, A, vy)[1]
    @test res ≈ value(res2)
    # Invertibility: running forward then backward must restore all inputs.
    @test check_inv(idot, (auto_allocate(idot, vx, A, vy), vx, A, vy); verbose=true)

    # Gradients: reversible AD vs numerical differentiation w.r.t. x.
    vx = sprand(100, 0.1)
    vy = sprand(100, 0.1)
    A = sprand(100, 100, 0.1)
    @i function loss(out, res, vx, A, vy)
        idot(res, vx, A, vy)
        out += res.r
    end
    ngrad = Vector(NiLang.AD.ng(loss, (0.0, auto_allocate(idot, vx, A, vy), copy(vx), A, vy), 3; iloss=1))
    grads = NiLang.AD.gradient(loss, (0.0, auto_allocate(idot, vx, A, vy), vx, A, vy); iloss=1)
    # Gradients only flow through x's stored entries; zero out the rest
    # of the numerical gradient before comparing.
    ngrad[setdiff(1:100, vx.nzind)] .= 0
    @test grads[3] ≈ ngrad
end

No memory overhead, but 2x slower

using NiLang, SparseArrays

# Reversible sparse dot product (approach 2): constant memory overhead via
# compute-copy-uncompute. Each column's partial result `di` is computed,
# copied into `r`, then uncomputed by ~@routine — roughly doubling the work
# but requiring only one reusable branch_keeper buffer.
@i function idot(r, x::SparseVector, A::SparseArrays.AbstractSparseMatrixCSC, y::SparseVector)
    m, n ← size(A)
    @safe length(x) == m && n == length(y) || throw(DimensionMismatch())
    @invcheckoff @inbounds if !(iszero(m) || iszero(n))
        # Scratch buffer for branch records, re-zeroed by the inner ~@routine
        # so it can be reused for every column.
        branch_keeper ← zeros(UInt8, size(A, 1))
        for j = 1:length(y.nzind)
            @routine begin
                # Range of A's stored entries in column y.nzind[j] (CSC layout).
                A_ptr_lo ← A.colptr[y.nzind[j]]
                A_ptr_hi ← A.colptr[y.nzind[j]+1] - 1
            end
            if A_ptr_lo <= A_ptr_hi
                @routine begin
                    @zeros Int branch_keeper_i xi yi
                    di ← zero(r)
                    acc_spdot(di, 1, xi, @const(length(x.nzind)), x.nzind, x.nzval,
                                          A_ptr_lo, yi, A_ptr_hi, A.rowval, A.nzval, branch_keeper, branch_keeper_i)
                end
                r += di * y.nzval[j]   # the "copy" step of compute-copy-uncompute
                ~@routine              # uncompute di, cursors, and branch records
            end
            ~@routine
        end
        branch_keeper → zeros(UInt8, size(A, 1))   # deallocate; must be all-zero again
    end
    m, n → size(A)   # deallocate (m, n)
end

using Test
@testset "sparse dot" begin
    # Correctness: compare against the irreversible SparseArrays reference.
    vx = sprand(1000, 0.2)
    vy = sprand(1000, 0.2)
    A = sprand(1000, 1000, 0.01)
    res = SparseArrays.dot(vx, A, vy)
    res2 = idot(0.0, vx, A, vy)[1]
    @test res ≈ res2
    # Invertibility: forward then backward must restore all inputs.
    @test check_inv(idot, (0.0, vx, A, vy); verbose=true)

    # Gradients: reversible AD vs numerical differentiation w.r.t. x.
    # (`r` is a plain scalar here, so idot itself can be the loss.)
    vx = sprand(100, 0.1)
    vy = sprand(100, 0.1)
    A = sprand(100, 100, 0.1)
    ngrad = Vector(NiLang.AD.ng(idot, (0.0, copy(vx), A, vy), 2; iloss=1))
    grads = NiLang.AD.gradient(idot, (0.0, vx, A, vy); iloss=1)
    # Gradients only flow through x's stored entries; zero out the rest.
    ngrad[setdiff(1:100, vx.nzind)] .= 0
    @test grads[2] ≈ ngrad
end

Notes

@routine

Here, we used the compute-copy-uncompute trick to uncompute a variable to zero.

@routine begin
    # allocate some ancillas with zero content
    # compute something...
end
# do something with the computed results (do not pollute variables used in uncomputing)...
~@routine   # uncompute the ancillas to zeros, and deallocate

Since uncomputing costs time, it gives a 2x slow down to the program.

While loop

@from condition1 while condition2
end

It means: starting from a state where condition1 holds, run the while loop governed by condition2. That is, condition1 must be true before the first iteration and become false after the first iteration — this is what makes the loop reversible. Its inverse is

@from !condition2 while !condition1
end

There are many other details in these implementations, feel free to ask more.