JuliaMolSim / AtomsBase.jl

A Julian abstract interface for atomic structures.
https://juliamolsim.github.io/AtomsBase.jl/
MIT License

Make some functions more AD friendly #91

Status: Open. DhairyaLGandhi opened this pull request 7 months ago.

DhairyaLGandhi commented 7 months ago

Zygote has a function called literal_indexed_iterate that types supporting iteration can overload to allow cleaner accumulation of gradients under AD. However, overloading it adds a dependency on Zygote, which might be costly for a base package.

Package extensions cannot be used either, since the extension would essentially overwrite existing methods, which amounts to some type piracy; overwriting methods this way is also disallowed as of Julia 1.10. This PR is therefore a simple way to still benefit from AD-friendly code generation without introducing any extra complexity.

jgreener64 commented 7 months ago

Seems okay to me. Is there a Zygote issue discussing why changing broadcast to map is required here? It might be worth referencing it in a code comment; otherwise this could get changed back in the future.

Beyond this PR we could think about adding a Zygote test if we want to make sure we don't break AD compat.
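To make the broadcast-versus-map change mentioned above concrete, here is a minimal Base-only sketch. ToyAtom, ToySystem, toy_position, positions_broadcast and positions_map are hypothetical stand-ins, not AtomsBase API, and this thread does not show the exact diff in #91. The sketch only shows that both accessor styles return the same positions; the difference is the code path, since only the broadcast form goes through Base.broadcastable, while Base's generic map(f, itr) collects a Generator instead.

```julia
# Hypothetical stand-ins for an AtomsBase system and its atoms (not real AtomsBase types).
struct ToyAtom
    position::Vector{Float64}
end

struct ToySystem
    atoms::Vector{ToyAtom}
end

# An iterable with known length and eltype that is NOT an AbstractArray,
# similar in spirit to an AbstractSystem.
Base.length(sys::ToySystem) = length(sys.atoms)
Base.eltype(::Type{ToySystem}) = ToyAtom
Base.iterate(sys::ToySystem, i::Int = 1) =
    i > length(sys.atoms) ? nothing : (sys.atoms[i], i + 1)

toy_position(atom::ToyAtom) = atom.position

# Broadcast form: calls Base.broadcastable(sys) before applying toy_position.
positions_broadcast(sys::ToySystem) = toy_position.(sys)

# map form: Base's generic map(f, itr) is collect(Base.Generator(f, itr)).
positions_map(sys::ToySystem) = map(toy_position, sys)

sys = ToySystem([ToyAtom([0.0, 0.0, 0.0]), ToyAtom([1.0, 0.5, 0.5])])
positions_broadcast(sys) == positions_map(sys)  # true
```

Keeping a single implementation like this, used identically with and without AD, matches the preference stated in this thread; whether a given form is easier for Zygote depends on which rules Zygote has for it.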

DhairyaLGandhi commented 7 months ago

Nothing has changed in Zygote here; what would be required is overloading Zygote.literal_indexed_iterate, and I was trying to avoid the dependency on Zygote. I also feel it is better if the implementation of a function doesn't change depending on whether or not AD is being used, since that can make gradient issues harder to debug.
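For context, literal_indexed_iterate hooks into destructuring assignment. In plain Julia, r1, r2 = p is lowered to calls to Base.indexed_iterate(p, 1) and Base.indexed_iterate(p, 2, state); as far as I understand, Zygote rewrites such calls with literal indices into its own literal_indexed_iterate(p, Val(i)), which is the method a package would overload. The sketch below uses no Zygote and a hypothetical TwoPositions type; it only illustrates the Base side of that mechanism.

```julia
# Hypothetical iterable holding two positions, used only to illustrate destructuring.
struct TwoPositions
    a::Vector{Float64}
    b::Vector{Float64}
end

# Iteration protocol: state 1 yields a, state 2 yields b, then stop.
Base.iterate(p::TwoPositions, state::Int = 1) =
    state == 1 ? (p.a, 2) : state == 2 ? (p.b, 3) : nothing

p = TwoPositions([0.0, 0.0, 0.0], [1.0, 0.5, 0.5])

# Destructuring goes through Base.indexed_iterate under the hood ...
r1, r2 = p

# ... which, for a generic iterable, amounts to these explicit calls.
first_el, state = Base.indexed_iterate(p, 1)
second_el, _ = Base.indexed_iterate(p, 2, state)
```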

DhairyaLGandhi commented 7 months ago

I like the idea of adding tests here.

jgreener64 commented 7 months ago

Definitely agree about not depending on Zygote and keeping the same implementation with and without AD. I'm just wondering why that overload is required at all; it sounds like something that could be tracked or improved in Zygote?

mfherbst commented 7 months ago

LGTM modulo adding a test. Would it make sense to make this part of AtomsBaseTesting, so that downstream codes are tested as well?

jgreener64 commented 7 months ago

It could be an optional extra or emit a warning in AtomsBaseTesting, but I don't think we should make Zygote compat a required part of the interface.

Tests checking that the systems in AtomsBase are Zygote-compatible would be useful to avoid regressions, though; I guess taking on Zygote as a test dependency is fine.
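A sketch of the kind of regression test discussed here, kept Base-only so it runs without Zygote: it checks a gradient function against central finite differences for a unitless version of the two-particle distance used in the reproductions in this thread. toy_dist, central_diff, check_gradient and analytic_grad are hypothetical helpers, and the analytic derivative stands in for an AD backend; with Zygote as a test dependency, grad could instead be x -> Zygote.gradient(toy_dist, x)[1].

```julia
using LinearAlgebra, Test

# Unitless distance between particles at [0, 0, 0] and [x, 0.5, 0.5]
# (hypothetical helper mirroring the reproductions in this thread).
function toy_dist(x::Real)
    sepvec = [x, 0.5, 0.5] .- [0.0, 0.0, 0.0]
    return sqrt(dot(sepvec, sepvec))
end

# Central finite-difference derivative of a scalar function of one variable.
central_diff(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

# Compare a gradient implementation `grad` (x -> df/dx) against finite differences.
function check_gradient(grad, f, xs; atol = 1e-6)
    for x in xs
        @test isapprox(grad(x), central_diff(f, x); atol = atol)
    end
end

# Analytic derivative d/dx sqrt(x^2 + 0.5), standing in for an AD result here.
analytic_grad(x) = x / sqrt(x^2 + 0.5)

@testset "toy_dist gradient" begin
    check_gradient(analytic_grad, toy_dist, [0.0, 0.3, 1.0, -2.0])
end
```

Checking several points, including nonzero x, matters because the derivative happens to vanish at x = 0.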

rkurchin commented 5 months ago

I was working on some tests to add to this and am running into a missing adjoint issue...

using AtomsBase, Zygote, Unitful, LinearAlgebra

box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
# atoms are built inside the closure below, because the positions depend on x

# distance between first two particles
function dist(sys::AbstractSystem)
    sepvec = diff(position(sys))[1]
    sqrt(dot(sepvec, sepvec))
end

gradient(0) do x
    positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
    atoms = [Atom(elements[i], positions[i]) for i in 1:2]
    flexible = FlexibleSystem(atoms, box, bcs)
    dist(flexible)
end

And I get a super long stacktrace that starts with:

ERROR: Need an adjoint for constructor StaticArrays.SVector{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}}. Gradient is of type StaticArrays.SVector{3, Float64}

I found a whole chain of discussions across several PRs on various packages (1 -> 2 -> 3 -> 4 -> 5 -> 6), but I don't follow the nitty-gritty details well enough to know whether that last one will fix this once merged (it also seems like it might depend on the Julia version; I was on 1.9.2). Hopefully @DhairyaLGandhi can lend some insight?
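For reference, the quantity differentiated in this reproduction has a simple closed form. It also shows that x = 0 is a special point: a correct gradient there is exactly zero, so a nonzero x gives a more informative test.

```math
\begin{aligned}
\mathbf{r}_2 - \mathbf{r}_1 &= (x,\ 0.5,\ 0.5)\ \mathrm{m}, \\
d(x) &= \lVert \mathbf{r}_2 - \mathbf{r}_1 \rVert = \sqrt{x^2 + 0.5^2 + 0.5^2}\ \mathrm{m} = \sqrt{x^2 + 0.5}\ \mathrm{m}, \\
\frac{\mathrm{d}d}{\mathrm{d}x} &= \frac{x}{\sqrt{x^2 + 0.5}}\ \mathrm{m}, \qquad \left.\frac{\mathrm{d}d}{\mathrm{d}x}\right|_{x=0} = 0\ \mathrm{m}.
\end{aligned}
```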

jgreener64 commented 5 months ago

I get a different issue, related to mutation with position.(sys). I am on Julia 1.10.0 and the latest StaticArrays, ChainRules etc.

using AtomsBase, Zygote, Unitful, LinearAlgebra

box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]

function dist(sys::AbstractSystem)
    sepvec = diff(position(sys))[1]
    sqrt(dot(sepvec, sepvec))
end

gradient(0) do x
    positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
    atoms = [Atom(elements[i], positions[i]) for i in 1:2]
    flexible = FlexibleSystem(atoms, box, bcs)
    dist(flexible)
end
ERROR: Mutating arrays is not supported -- called copyto!(Vector{Atom{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}, Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Atom{3, Quantity{…}, Quantity{…}, Quantity{…}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:70
  [3] (::Zygote.var"#543#544"{Vector{…}})(::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:85
  [4] (::Zygote.var"#2633#back#545"{Zygote.var"#543#544"{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
    @ Base ./array.jl:765 [inlined]
  [6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [7] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{…}, Vector{…}}}, Base.var"#4#5"{Zygote.var"#1366#1372"}})
    @ Base ./array.jl:759 [inlined]
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [9] broadcastable
    @ ./broadcast.jl:743 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [11] broadcasted
    @ ./broadcast.jl:1339 [inlined]
 [12] position
    @ ~/.julia/dev/AtomsBase/src/interface.jl:139 [inlined]
 [13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{StaticArraysCore.SVector{…}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [14] dist
    @ ./REPL[5]:2 [inlined]
 [15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [16] #1
    @ ./REPL[6]:5 [inlined]
 [17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
 [19] gradient(::Function, ::Int64, ::Vararg{Int64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [20] top-level scope
    @ REPL[6]:1
Some type information was truncated. Use `show(err)` to see complete types.
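Frames [9] to [12] show that position.(sys) passes through Base.broadcastable before position is applied. Base passes arrays through unchanged, but for a non-array iterable the fallback is collect(x), and collect on an iterator with known length and eltype fills a freshly allocated array via copyto!. The trace suggests that this is the copyto! Zygote reports as mutation. A minimal Base-only sketch of the fallback, using hypothetical ToyAtom/ToySystem stand-ins rather than AtomsBase types:

```julia
# Hypothetical stand-ins; not the real AtomsBase types.
struct ToyAtom
    position::Vector{Float64}
end

struct ToySystem
    atoms::Vector{ToyAtom}
end

# Iterable with known length and eltype, but not an AbstractArray.
Base.length(sys::ToySystem) = length(sys.atoms)
Base.eltype(::Type{ToySystem}) = ToyAtom
Base.iterate(sys::ToySystem, i::Int = 1) =
    i > length(sys.atoms) ? nothing : (sys.atoms[i], i + 1)

sys = ToySystem([ToyAtom([0.0, 0.0, 0.0]), ToyAtom([1.0, 0.5, 0.5])])

# An array is broadcast over as-is: no copy is made.
@show Base.broadcastable(sys.atoms) === sys.atoms

# A non-array iterable falls back to collect(sys): a new Vector{ToyAtom}
# is allocated and filled from the iterator.
bc = Base.broadcastable(sys)
@show typeof(bc)
@show bc !== sys.atoms
```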

DhairyaLGandhi commented 5 months ago

So far, I've been using build_graph from https://github.com/Chemellia/AtomGraphs.jl/pull/11 as the test case for an AD code path. I don't think Unitful is fully supported by AD.

jgreener64 commented 5 months ago

I have always found it hard to get units to play well with AD, and don't use them when taking gradients in my own code.

There is https://github.com/SBuercklin/UnitfulChainRules.jl which may be useful.