GiggleLiu / NiLang.jl

A differential eDSL that can run faster than light and go back to the past.
https://giggleliu.github.io/NiLang.jl/dev
Apache License 2.0

broadcasting: Array/scalar + scalar #63

Open johnnychen94 opened 3 years ago

johnnychen94 commented 3 years ago

I thought something like this should work:

@i function add1!(out!)
    out! .+= 1
end

x = zeros(Int, 4, 4)
add1!(x)
# ERROR: MethodError: no method matching unzipped_broadcast(::PlusEq{typeof(identity)}, ::Matrix{Int64}, ::Int64)

A workaround is to manually write the loop:

@i function add1_new!(out!)
    for i in 1:length(out!)
        out![i] += 1
        # out![i] .+= 1 # doesn't work either
    end
end

add1_new!(x)

Similarly, 1 .+ 1 works in plain Julia, but out![i] .+= 1 does not work in NiLang (see the comment above).

Is this a bug or expected behavior?

GiggleLiu commented 3 years ago

Sorry for the late reply. This is deliberately forbidden in NiLang so that gradients are calculated correctly. A variable that is a shared read in the forward pass can become a shared write in the reverse-mode AD pass. In broadcasting, a scalar is effectively shared across all elements.

This is called the shared-read-write problem: https://giggleliu.github.io/NiLang.jl/dev/notebooks/documentation.html#Multi-threading_and_CUDA
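
A minimal plain-Julia sketch (not NiLang code) of why a broadcast scalar is a shared read in the forward pass and a shared write in the backward pass. The names forward!, backward, gx, gy, and gz are hypothetical, and the adjoint of y[i] += x * z[i] is written by hand here rather than generated by NiLang.

```julia
# Forward pass: every iteration READS the same scalar x.
function forward!(y::Vector{Float64}, x::Float64, z::Vector{Float64})
    for i in eachindex(y, z)
        y[i] += x * z[i]
    end
    return y
end

# Hand-written backward pass for the loop above.
# gy[i] is the gradient of the loss with respect to y[i].
function backward(gy::Vector{Float64}, x::Float64, z::Vector{Float64})
    gx = 0.0
    gz = zeros(length(z))
    for i in eachindex(gy, z)
        gx += gy[i] * z[i]   # every iteration WRITES to the same accumulator
        gz[i] += gy[i] * x   # each iteration writes to its own slot
    end
    return gx, gz
end

z = [1.0, 2.0, 3.0]
y = forward!(zeros(3), 2.0, z)
gx, gz = backward(ones(3), 2.0, z)
```

Run serially, the accumulation into gx is harmless. If the backward loop were run across threads, the updates to gx would race, which is the situation the restriction is meant to rule out.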

johnnychen94 commented 3 years ago

I see, that makes sense. After reading the link you shared, I have two more questions about this example:

@i function shared_read(loss::Real, y::Vector, x::Real, z::Vector)
    @safe @assert length(z) == length(y)
    @threads for i=1:length(y)
        y[i] += x * z[i]
    end
    for i=1:length(y)
        loss += y[i]
    end
end
  1. What if @threads is removed here? My understanding is: it's still a shared read but no race condition involved, so it will not cause any problem.
  2. How can we properly work around this issue (e.g., for this specific shared_read function) in NiLang? For example, in plain Julia we add locks or copy the data. Does NiLang provide a similar strategy?
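
For reference, the "copy the data" strategy from question 2 can be sketched in plain Julia (not NiLang) as follows. The function name backward_x_threaded and the buffer partial are hypothetical. The idea is to give each iteration of the backward pass its own slot and reduce afterwards, so that no two threads write to the same location.

```julia
using Base.Threads: @threads

# Backward pass for y[i] += x * z[i] with respect to x, parallelized by giving
# every iteration a private slot instead of one shared accumulator.
function backward_x_threaded(gy::Vector{Float64}, z::Vector{Float64})
    partial = zeros(length(z))      # one private slot per iteration
    @threads for i in 1:length(z)
        partial[i] = gy[i] * z[i]   # no two iterations touch the same slot
    end
    return sum(partial)             # serial reduction after the parallel loop
end

gx = backward_x_threaded(ones(4), [1.0, 2.0, 3.0, 4.0])
```

The price is extra memory proportional to the number of iterations, and it only works because the shared variable is known ahead of time.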

GiggleLiu commented 3 years ago

> it's still a shared read but no race condition involved, so it will not cause any problem.

You are right. Actually, it is possible to support broadcasting better by rewriting the broadcast in NiLang's grammar.

> For example, in plain Julia we add locks or copy the data. Does NiLang provide a similar strategy?

No, it is very hard to add a lock. Consider someone randomly accessing elements of a vector: you do not know when a shared read happens.
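
A small plain-Julia sketch of the random-access case, assuming a gather-style loop y[i] += x[idx[i]]. The names gather!, gather_backward, and idx are hypothetical, and the backward pass is written by hand. Which elements of x are shared depends on the run-time contents of idx, so it cannot be determined from the code alone.

```julia
# Forward pass: reads x at data-dependent positions.
function gather!(y::Vector{Float64}, x::Vector{Float64}, idx::Vector{Int})
    for i in eachindex(y, idx)
        y[i] += x[idx[i]]
    end
    return y
end

# Hand-written backward pass: duplicate entries in idx write to the same slot of gx.
function gather_backward(gy::Vector{Float64}, idx::Vector{Int}, n::Int)
    gx = zeros(n)
    for i in eachindex(gy, idx)
        gx[idx[i]] += gy[i]
    end
    return gx
end

idx = [1, 3, 3, 2, 3]   # x[3] is read three times, known only at run time
y = gather!(zeros(5), [10.0, 20.0, 30.0], idx)
gx = gather_backward(ones(5), idx, 3)
```

Here gx[3] receives three contributions while gx[1] and gx[2] receive one each. A lock would have to guard every element of gx, since any of them might turn out to be shared.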

johnnychen94 commented 3 years ago

Thanks.

I'll keep this issue open because I feel we could have more examples explaining these constraints; the shared read/write issue is quite a generic problem.

I'll close it when I come up with a docs PR.

GiggleLiu commented 3 years ago

Sounds good