GiggleLiu / NiLang.jl

A differential eDSL that can run faster than light and go back to the past.
https://giggleliu.github.io/NiLang.jl/dev
Apache License 2.0

Some trivial questions #78

Closed sumuzhao closed 2 years ago

sumuzhao commented 3 years ago

Hey GiggleLiu,

Your work is fantastic. I'm new to Julia and NiLang, and I'm still struggling to refactor my function into a reversible one, even after reading your tutorial and documentation. There may be some simple cases that are not covered, or perhaps I missed them.

  1. How do I reset an external variable to zero inside a loop? I also find that the external variable is "stateless" inside the loop: it is zero each time the loop body is entered, regardless of any value assigned to it before. If I really need a global variable to record some state, what should I do?
    # just an example, do not care about its logic
    @i function f1(y::T, x::AbstractVector{T}) where T
    @routine begin
        @zeros T anc
        for i = 1:length(x)
            if x[i] % 2 == 0
                anc += x[i]
            else
                ~@zeros T anc        # throws a deallocation error here; not sure how to fix it
            end
        end
    end
    ~@routine
    end
  2. Are common mathematical functions supported?
    min(1, 20, 30.2)
    abs(-1.2)
  3. Is round() supported?
    x = round(1.2345, digits=2)
  4. How to deal with if...elseif...elseif... with multiple conditions? Is it reversible?
    if x > 0 && y > 0            # here how can we use if (condition, ~) ?
    do something
    elseif x > 0 && y < 0
    do something
    elseif x < 0 && y > 0
    do something
    elseif x < 0 && y < 0
    do something
    end
  5. Is calculation in if conditions supported?
    if x > 0 && y > x * 2 + z / 10
    do something
    end

Here is a toy example.

@i function f1(y::T, x::Vector{T}, pos::T) where T

    @routine begin
    for i = 1:length(x)
        @safe println(i, ", check pos, ", pos)
        if (pos >= 0 && pos <= 3, ~)
            @routine begin
                @zeros T anc
                anc += i
            end
            pos += anc
            @safe println("cond1, ", pos)
            ~@routine
        else
            @routine begin
                @zeros T anc
                anc -= i
            end
            pos += anc
            @safe println("cond1, ", pos)
            ~@routine
        end
        @safe println("recheck, ", pos)
    end
    end
    y += pos
    ~@routine
end

y, x, pos = 0.0, [i * 1.0 for i = 1:10], 0.0
y, x, pos = f1(y, x, pos)

Then the output is,

1, check pos, 0.0
cond1, 1.0
recheck, 1.0
2, check pos, 1.0
cond1, 3.0
recheck, 3.0
3, check pos, 3.0
cond1, 6.0
deallocate fail `##branch#10337 → pos >= 0 && pos <= 3`
InvertibilityError("deallocate fail (primitive): true ≂̸ false")

Stacktrace:
 [1] f1(y::Float64, x::Vector{Float64}, pos::Float64)
   @ Main ./In[432]:16
 [2] top-level scope
   @ In[434]:1
 [3] eval
   @ ./boot.jl:360 [inlined]
 [4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1116

How to deal with this issue?

Thanks a lot!

GiggleLiu commented 3 years ago

Short answers to your questions:

  1. You cannot deallocate a variable in a scope different from the one in which it was allocated. It does not make sense to deallocate inside a for-loop, because you would risk deallocating the same variable multiple times.

2, 3. You can always use += to make a statement reversible:

y += min(a, b)
y += abs(a)
y += round(a; digits=2)  # not differentiable

The only problem is that some of these functions are not differentiable.

  4. It is fine to write if-elseif statements, but it is not recommended. You need to make sure that do something does not change the condition (the easiest way is to store the value of the condition in a new variable before entering the branch). Otherwise, you need a postcondition.
boolean += expr
if boolean
    do something
else
    do something
end

You might want to preallocate the boolean variable.

  5. Yes.
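
The advice about storing the condition in a separate variable can be illustrated with a minimal sketch in plain Julia. This is not NiLang code, and the helpers forward, backward, and naive_backward are hypothetical names introduced only for this illustration. It shows why an inverse pass that re-evaluates the condition on the updated value can take the wrong branch, while a recorded flag always selects the right one.

```julia
# Plain-Julia sketch (not NiLang); all helper names are hypothetical.

# Forward step: record which branch is taken BEFORE pos is modified.
function forward(pos, dx)
    flag = pos >= 0
    pos = flag ? pos + dx : pos - dx
    return pos, flag
end

# Inverse step that trusts the recorded flag.
backward(pos, dx, flag) = flag ? pos - dx : pos + dx

# Naive inverse that re-evaluates the condition on the already-updated pos.
naive_backward(pos, dx) = pos >= 0 ? pos - dx : pos + dx

pos0, dx = -1.0, -3.0
pos1, flag = forward(pos0, dx)        # else-branch: pos1 == 2.0, flag == false

println(backward(pos1, dx, flag))     # recovers the original -1.0
println(naive_backward(pos1, dx))     # picks the wrong branch and returns 5.0
```

In NiLang the recorded flag plays the role of the postcondition: it must evaluate, after the branch body has run, to the same truth value the precondition had before it.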

WRT your toy example, can you please paste its irreversible version? I can help you make it reversible.

sumuzhao commented 3 years ago

Thanks for your reply! Questions 2, 3, and 5 are clear now, but I'm still unsure about 1 and 4. Here is the toy example in its irreversible version. The logic is arbitrary, but it covers some points that I do not know how to deal with.

function f1(y::T, x::Vector{T}, pos::T) where T

    s = 0     # this is an external variable outside the loop to store some summations. 
    for i = 1:length(x)
        println(i, ", check pos, ", pos)
        anc = 0
        if pos >= 0 && pos <= 3        # chained conditions
            anc += x[i]
            pos += anc
            println("cond1 (pos >= 0 && pos <= 3), ", pos)
        elseif pos > 3 && pos <= 6
            anc -= x[i]
            pos += anc
            println("cond2 (pos > 3 && pos <= 6), ", pos)
        elseif pos > 6
            anc -= x[i] / 2
            pos += anc
            println("cond3 (pos > 6), ", pos)
            if pos > 2                # nested if condition
                s += pos
            end
        else
            anc += x[i] / 2
            pos += anc
            println("cond4 (pos < 0), ", pos)
            if pos < 2
                s += pos
            end
        end
    end
    y += s
end

y, x, pos = 0.0, [i * 1.0 for i = 1:10], 0.0
f1(y, x, pos)

Then the output is,

1, check pos, 0.0
cond1 (pos >= 0 && pos <= 3), 1.0
2, check pos, 1.0
cond1 (pos >= 0 && pos <= 3), 3.0
3, check pos, 3.0
cond1 (pos >= 0 && pos <= 3), 6.0
4, check pos, 6.0
cond2 (pos > 3 && pos <= 6), 2.0
5, check pos, 2.0
cond1 (pos >= 0 && pos <= 3), 7.0
6, check pos, 7.0
cond3 (pos > 6), 4.0
7, check pos, 4.0
cond2 (pos > 3 && pos <= 6), -3.0
8, check pos, -3.0
cond4 (pos < 0), 1.0
9, check pos, 1.0
cond1 (pos >= 0 && pos <= 3), 10.0
10, check pos, 10.0
cond3 (pos > 6), 5.0
10.0

GiggleLiu commented 3 years ago

This is a working version,

using NiLang

@i function f1(s::T, y::T, x::Vector{T}, pos::T, branch_keeper::AbstractVector{Int}) where T
    for i = 1:length(x)
        @safe println(i, ", check pos, ", pos)
        if (pos >= 0 && pos <= 3, branch_keeper[i]==1)        # chained conditions
            branch_keeper[i] += 1
            @routine begin
                anc ← zero(T)
                anc += x[i]
            end
            pos += anc
            ~@routine
            @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
        elseif (pos > 3 && pos <= 6, branch_keeper[i]==2)
            branch_keeper[i] += 2
            @routine begin
                anc ← zero(T)
                anc -= x[i]
            end
            pos += anc
            ~@routine
            @safe println("cond2 (pos > 3 && pos <= 6), ", pos)
        elseif (pos > 6, branch_keeper[i]==3)
            branch_keeper[i] += 3
            @routine begin
                anc ← zero(T)
                anc -= x[i] / 2
            end
            pos += anc
            ~@routine
            @safe println("cond3 (pos > 6), ", pos)
            if pos > 2                # nested if condition
                s += pos
            end
        else
            @routine begin
                anc ← zero(T)
                anc += x[i] / 2
            end
            pos += anc
            ~@routine
            @safe println("cond4 (pos < 0), ", pos)
            if pos < 2
                s += pos
            end
        end
    end
    y += s
end

y, x, pos, branch_keeper = 0.0, [i * 1.0 for i = 1:10], 0.0, zeros(Int, 10)
NiLang.check_inv(f1, (0.0, y, x, pos, branch_keeper))  # check reversibility (to ensure round off error does not ruin reversibility)
f1(0.0, y, x, pos, branch_keeper)

Hope this example clarifies your questions.

sumuzhao commented 3 years ago

The branch_keeper is the magic! But if I add some code like the snippet below, I get a deallocation error again.

......
if (pos >= 0 && pos <= 3, branch_keeper[i]==1)        # chained conditions
    branch_keeper[i] += 1
    @routine begin
        anc ← zero(T)
        tmp ← zero(T)
        tmp += pos - 2   # HERE! Quite weird: only pos - 2 raises the error. x[i] - 2, y - 2, x - 2 are all okay.
        anc += x[i]
    end
    pos += anc   # I guess the problem is due to this? pos is assigned? 
    ~@routine
    @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
......

Then the error is,

1, check pos, 0.0
deallocate fail `tmp → zero(T)`
InvertibilityError("deallocate fail (floating point numbers): -1.1234 ≂̸ 0.0")

Stacktrace:
 [1] f1(s::Float64, y::Float64, x::Vector{Float64}, pos::Float64, branch_keeper::Vector{Int64})
   @ Main ./In[622]:8
 [2] top-level scope
   @ In[623]:3
 [3] eval
   @ ./boot.jl:360 [inlined]
 [4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1116

For this problem, I found that using a tmp array, like the branch_keeper array, works. But I do not know why.

BTW, after successfully refactoring the above function into a reversible one, I tried to calculate the gradient but sadly ran into some errors. I realize the problem is that the function itself is not continuous, that is, not differentiable, because of these if-else conditions. So 'reversible' does not imply 'differentiable'. It's a non-trivial task to make a non-continuous function differentiable...... 😭

GiggleLiu commented 3 years ago

Because your new code is not reversible. ~@routine runs the inverse of the statements in the @routine block, as the expanded code shows:

julia> NiLangCore.precom_ex(Main, :(if (pos >= 0 && pos <= 3, branch_keeper[i]==1)        # chained conditions
           branch_keeper[i] += 1
           @routine begin
               anc ← zero(T)
               tmp ← zero(T)
               tmp += pos - 2   # HERE! Quite weird: only pos - 2 raises the error. x[i] - 2, y - 2, x - 2 are all okay.
               anc += x[i]
           end
           pos += anc   # I guess the problem is due to this? pos is assigned? 
           ~@routine
           @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
           end)) |> NiLangCore.rmlines
:(if (pos >= 0 && pos <= 3, branch_keeper[i] == 1)
      branch_keeper[i] += identity(1)
      begin
          anc ← zero(T)
          tmp ← zero(T)
          tmp += pos - 2
          anc += identity(x[i])
      end
      pos += identity(anc)
      begin
          anc -= identity(x[i])
          tmp -= pos - 2
          tmp → zero(T)
          anc → zero(T)
      end
      @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
  end)

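Since you changed pos between @routine and ~@routine, the uncomputing step subtracts a different value of pos - 2 than the one that was added, so it fails to clear tmp.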
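
The compute, modify, uncompute sequence from the expanded code can be traced by hand in plain Julia, without NiLang. This is only a sketch: the inputs pos = 0.0 and x[1] = 1.0 are assumed here for illustration (they match the toy example's first iteration), and the exact residual left in tmp depends on the inputs used.

```julia
# Hand trace of the expanded branch body in plain Julia (not NiLang).
pos, xi = 0.0, 1.0     # assumed inputs: pos and x[1] from the toy example
anc, tmp = 0.0, 0.0    # ancillas start at zero

# Compute: the body of @routine.
tmp += pos - 2         # uses the OLD pos, so tmp == -2.0
anc += xi              # anc == 1.0

# The statement between @routine and ~@routine modifies pos.
pos += anc             # pos == 1.0

# Uncompute: ~@routine inverts each statement in reverse order.
anc -= xi              # anc == 0.0, cleared correctly
tmp -= pos - 2         # uses the NEW pos, subtracts -1.0, so tmp == -1.0

println((anc, tmp))    # anc is back to zero, tmp is not
```

anc is restored because x[i] is untouched between the two passes, whereas tmp depends on pos, which was updated in between. Keeping pos out of the @routine block, or storing the intermediate value in preallocated storage that is not uncomputed, avoids the failed deallocation.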