Closed sumuzhao closed 2 years ago
Short answers to your questions:
2, 3. You can always use += to make a statement reversible:
y += min(a, b)
y += abs(a)
y += round(a; digits=2) # not differentiable
The only problem is that some of these are not differentiable.
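To illustrate why a `+=` statement can always be undone, here is a minimal sketch in plain Julia (no NiLang involved). The helper names `acc_forward` and `acc_backward` are hypothetical and only emulate by hand what the reversible `y += min(a, b)` and its inverse `y -= min(a, b)` do: the inputs `a` and `b` are left untouched, so the same right-hand side can be recomputed and subtracted again.

```julia
# Hand-written emulation of `y += min(a, b)` and its inverse.
# `acc_forward` / `acc_backward` are illustrative names, not NiLang API.
acc_forward(y, a, b) = (y + min(a, b), a, b)    # y += min(a, b)
acc_backward(y, a, b) = (y - min(a, b), a, b)   # y -= min(a, b)

y, a, b = 0.5, 3.0, 2.0
y1, a1, b1 = acc_forward(y, a, b)       # y1 holds y + min(a, b)
y2, a2, b2 = acc_backward(y1, a1, b1)   # y2 is back to the original y
println((y1, y2))
```

The round trip works even though `min` itself is not invertible, because only the accumulator `y` is modified.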
if-elseif statements are supported, but not recommended. You need to make sure that "do something" does not change the condition (the easiest way is to store the condition in a new variable before entering the if statement). Otherwise, you need a postcondition.
boolean += expr
if boolean
    do something
else
    do something
end
You might want to preallocate the boolean variable.
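The idea of recording the condition before branching can be sketched in plain Julia as follows. This is a hand emulation under stated assumptions, not NiLang code: `branch_forward` and `branch_backward` are hypothetical names, and `xor` stands in for the reversible update of the boolean. The branch body deliberately changes the truth of `x > 0`, which is exactly the case where the stored flag is needed to pick the right branch on the way back.

```julia
# Forward: record the condition first, then branch on the recorded flag.
function branch_forward(x::Float64, flag::Bool)
    flag = xor(flag, x > 0)    # plays the role of `boolean += expr`
    if flag
        x -= 10.0              # this update changes the truth of `x > 0`
    else
        x += 10.0
    end
    return x, flag
end

# Backward: branch on the stored flag, undo the update, then clear the flag.
function branch_backward(x::Float64, flag::Bool)
    if flag
        x += 10.0
    else
        x -= 10.0
    end
    flag = xor(flag, x > 0)    # x is restored, so the same condition is recomputed
    return x, flag
end

x1, flag1 = branch_forward(3.0, false)
x0, flag0 = branch_backward(x1, flag1)
println((x1, flag1, x0, flag0))
```

Re-evaluating `x > 0` directly in the backward pass would select the wrong branch here, since `x1` is negative.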
WRT your toy example, can you please paste its irreversible version? I can help you make it reversible.
Thanks for your reply! Questions 2, 3, and 5 are clear now, but I'm still not sure about 1 and 4. Here is the toy example in its irreversible version. The logic is arbitrary, but it covers some points that I do not know how to deal with.
function f1(y::T, x::Vector{T}, pos::T) where T
    s = 0 # this is an external variable outside the loop to store some summations.
    for i = 1:length(x)
        println(i, ", check pos, ", pos)
        anc = 0
        if pos >= 0 && pos <= 3 # chained conditions
            anc += x[i]
            pos += anc
            println("cond1 (pos >= 0 && pos <= 3), ", pos)
        elseif pos > 3 && pos <= 6
            anc -= x[i]
            pos += anc
            println("cond2 (pos > 3 && pos <= 6), ", pos)
        elseif pos > 6
            anc -= x[i] / 2
            pos += anc
            println("cond3 (pos > 6), ", pos)
            if pos > 2 # nested if condition
                s += pos
            end
        else
            anc += x[i] / 2
            pos += anc
            println("cond4 (pos < 0), ", pos)
            if pos < 2
                s += pos
            end
        end
    end
    y += s
end
y, x, pos = 0.0, [i * 1.0 for i = 1:10], 0.0
f1(y, x, pos)
Then the output is,
1, check pos, 0.0
cond1 (pos >= 0 && pos <= 3), 1.0
2, check pos, 1.0
cond1 (pos >= 0 && pos <= 3), 3.0
3, check pos, 3.0
cond1 (pos >= 0 && pos <= 3), 6.0
4, check pos, 6.0
cond2 (pos > 3 && pos <= 6), 2.0
5, check pos, 2.0
cond1 (pos >= 0 && pos <= 3), 7.0
6, check pos, 7.0
cond3 (pos > 6), 4.0
7, check pos, 4.0
cond2 (pos > 3 && pos <= 6), -3.0
8, check pos, -3.0
cond4 (pos < 0), 1.0
9, check pos, 1.0
cond1 (pos >= 0 && pos <= 3), 10.0
10, check pos, 10.0
cond3 (pos > 6), 5.0
10.0
This is a working version,
using NiLang
@i function f1(s::T, y::T, x::Vector{T}, pos::T, branch_keeper::AbstractVector{Int}) where T
    for i = 1:length(x)
        @safe println(i, ", check pos, ", pos)
        if (pos >= 0 && pos <= 3, branch_keeper[i]==1) # chained conditions
            branch_keeper[i] += 1
            @routine begin
                anc ← zero(T)
                anc += x[i]
            end
            pos += anc
            ~@routine
            @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
        elseif (pos > 3 && pos <= 6, branch_keeper[i]==2)
            branch_keeper[i] += 2
            @routine begin
                anc ← zero(T)
                anc -= x[i]
            end
            pos += anc
            ~@routine
            @safe println("cond2 (pos > 3 && pos <= 6), ", pos)
        elseif (pos > 6, branch_keeper[i]==3)
            branch_keeper[i] += 3
            @routine begin
                anc ← zero(T)
                anc -= x[i] / 2
            end
            pos += anc
            ~@routine
            @safe println("cond3 (pos > 6), ", pos)
            if pos > 2 # nested if condition
                s += pos
            end
        else
            @routine begin
                anc ← zero(T)
                anc += x[i] / 2
            end
            pos += anc
            ~@routine
            @safe println("cond4 (pos < 0), ", pos)
            if pos < 2
                s += pos
            end
        end
    end
    y += s
end
y, x, pos, branch_keeper = 0.0, [i * 1.0 for i = 1:10], 0.0, zeros(Int, 10)
NiLang.check_inv(f1, (0.0, y, x, pos, branch_keeper)) # check reversibility (to ensure round-off error does not ruin reversibility)
f1(0.0, y, x, pos, branch_keeper)
Hope this example clarifies your questions.
The branch_keeper is the magic! But if I add some code like the lines below, it gives me a deallocation error again.
......
        if (pos >= 0 && pos <= 3, branch_keeper[i]==1) # chained conditions
            branch_keeper[i] += 1
            @routine begin
                anc ← zero(T)
                tmp ← zero(T)
                tmp += pos - 2 # HERE! Quite weird, only pos - 2 will raise the error. x[i] - 2, y - 2, x - 2 are all okay.
                anc += x[i]
            end
            pos += anc # I guess the problem is due to this? pos is assigned?
            ~@routine
            @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
......
Then the error is,
1, check pos, 0.0
deallocate fail `tmp → zero(T)`
InvertibilityError("deallocate fail (floating point numbers): -1.1234 ≂̸ 0.0")
Stacktrace:
[1] f1(s::Float64, y::Float64, x::Vector{Float64}, pos::Float64, branch_keeper::Vector{Int64})
@ Main ./In[622]:8
[2] top-level scope
@ In[623]:3
[3] eval
@ ./boot.jl:360 [inlined]
[4] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:1116
For this problem, I found that using a tmp array, like the branch_keeper array, works. But I do not know why.
BTW, after successfully refactoring the above function into a reversible one, I tried to calculate the gradient but sadly ran into some errors. I realize the problem is that the function itself is not continuous, and therefore not differentiable, due to these if-else conditions. So 'reversible' does not imply 'differentiable'. It's a non-trivial task to make a non-continuous function differentiable...... 😭
Because your new code is not reversible. ~@routine computes the inverse pass.
julia> NiLangCore.precom_ex(Main, :(if (pos >= 0 && pos <= 3, branch_keeper[i]==1) # chained conditions
           branch_keeper[i] += 1
           @routine begin
               anc ← zero(T)
               tmp ← zero(T)
               tmp += pos - 2 # HERE! Quite weird, only pos - 2 will raise the error. x[i] - 2, y - 2, x - 2 are all okay.
               anc += x[i]
           end
           pos += anc # I guess the problem is due to this? pos is assigned?
           ~@routine
           @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
       end)) |> NiLangCore.rmlines
:(if (pos >= 0 && pos <= 3, branch_keeper[i] == 1)
      branch_keeper[i] += identity(1)
      begin
          anc ← zero(T)
          tmp ← zero(T)
          tmp += pos - 2
          anc += identity(x[i])
      end
      pos += identity(anc)
      begin
          anc -= identity(x[i])
          tmp -= pos - 2
          tmp → zero(T)
          anc → zero(T)
      end
      @safe println("cond1 (pos >= 0 && pos <= 3), ", pos)
  end)
Since you changed pos between @routine and ~@routine, the uncomputing step evaluates tmp -= pos - 2 with the new value of pos, so it fails to clear tmp.
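The failure can be traced by hand in plain Julia (no NiLang involved). This is only a sketch that follows the expanded statements shown above, assuming the first-iteration values of the toy example, `pos = 0.0` and `x[1] = 1.0`; the variable `xi` is a stand-in for `x[i]`, and the exact leftover in `tmp` depends on the inputs.

```julia
# Hand emulation of the expanded compute / modify / uncompute sequence.
pos, xi = 0.0, 1.0   # assumed inputs; `xi` stands in for x[i]

# compute (the `@routine` block)
anc = 0.0
tmp = 0.0
tmp += pos - 2       # uses the old pos
anc += xi

pos += anc           # pos changes between compute and uncompute

# uncompute (the `~@routine` block)
anc -= xi            # x[i] is unchanged, so anc returns to zero
tmp -= pos - 2       # uses the new pos, so tmp does not return to zero

println((anc, tmp))
```

`anc` is cleared because it depends only on `x[i]`, which nothing modified in between, while `tmp` depends on `pos` and is left nonzero. This is why `x[i] - 2` or `y - 2` in place of `pos - 2` does not trigger the deallocation error.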
Hey GiggleLiu,
Your work is fantastic. I'm new to Julia and NiLang. I'm still a bit frustrated with refactoring my function into a reversible one, even after reading your tutorial and documentation. There might still be some simple cases which are not covered, or perhaps I did not notice them.
Here is a toy example.
Then the output is,
How to deal with this issue?
Thanks a lot!