Open DhairyaLGandhi opened 7 months ago
Seems okay to me. Is there a Zygote issue discussing why changing broadcast to map is required here? It might be worth referencing that in a code comment; otherwise this could get changed back in the future.
Beyond this PR we could think about adding a Zygote test if we want to make sure we don't break AD compat.
Zygote hasn't changed here; what is required is overloading Zygote.literal_indexed_iterate. I was trying to avoid the dependency on Zygote. I also feel it would be better if the implementation of the function were the same with or without AD, since a different implementation under AD can make gradient issues harder to debug.
I like the idea of adding tests here
Definitely agree about not depending on Zygote and the same implementation with/without AD. I'm just wondering why that overload is required at all, sounds like something that could be tracked/improved in Zygote?
LGTM modulo adding a test. Would it make sense to make this part of AtomsBaseTesting so that it is also tested in downstream codes?
It could be an optional extra or emit a warning in AtomsBaseTesting, but I don't think we should make Zygote compat a required part of the interface.
Tests to check that the systems in AtomsBase are Zygote-compatible would be useful to avoid regressions, though. I guess taking on Zygote as a test dependency is fine.
I was working on some tests to add to this and am running into a missing adjoint issue...
using AtomsBase, Zygote, Unitful, LinearAlgebra
box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
# distance between first two particles
function dist(sys::AbstractSystem)
sepvec = diff(position(sys))[1]
sqrt(dot(sepvec, sepvec))
end
gradient(0) do x
positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
atoms = [Atom(elements[i], positions[i]) for i in 1:2]
flexible = FlexibleSystem(atoms, box, bcs)
dist(flexible)
end
And I get a super long stacktrace that starts with:
ERROR: Need an adjoint for constructor StaticArrays.SVector{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}}. Gradient is of type StaticArrays.SVector{3, Float64}
I found a whole chain of discussions across several PRs on various packages (1 -> 2 -> 3 -> 4 -> 5 -> 6), and I don't follow enough of the nitty-gritty details to know for sure whether that last one will fix this when merged (it also seems like it might depend on the Julia version? I was doing this on 1.9.2), but hopefully @DhairyaLGandhi can lend some insight?
I get a different issue, related to mutation with position.(sys). I am on Julia 1.10.0 and the latest StaticArrays, ChainRules etc.
using AtomsBase, Zygote, Unitful, LinearAlgebra
box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"m"
bcs = [Periodic(), Periodic(), DirichletZero()]
elements = [:C, :C]
function dist(sys::AbstractSystem)
sepvec = diff(position(sys))[1]
sqrt(dot(sepvec, sepvec))
end
gradient(0) do x
positions = [[0, 0, 0], [x, 0.5, 0.5]]u"m"
atoms = [Atom(elements[i], positions[i]) for i in 1:2]
flexible = FlexibleSystem(atoms, box, bcs)
dist(flexible)
end
ERROR: Mutating arrays is not supported -- called copyto!(Vector{Atom{3, Quantity{Float64, 𝐋, Unitful.FreeUnits{(m,), 𝐋, nothing}}, Quantity{Float64, 𝐋 𝐓^-1, Unitful.FreeUnits{(a₀, s^-1), 𝐋 𝐓^-1, nothing}}, Quantity{Float64, 𝐌, Unitful.FreeUnits{(u,), 𝐌, nothing}}}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Vector{Atom{3, Quantity{…}, Quantity{…}, Quantity{…}}})
@ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:70
[3] (::Zygote.var"#543#544"{Vector{…}})(::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:85
[4] (::Zygote.var"#2633#back#545"{Zygote.var"#543#544"{…}})(Δ::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[5] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
@ Base ./array.jl:765 [inlined]
[6] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[7] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{…}, Vector{…}}}, Base.var"#4#5"{Zygote.var"#1366#1372"}})
@ Base ./array.jl:759 [inlined]
[8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[9] broadcastable
@ ./broadcast.jl:743 [inlined]
[10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{@NamedTuple{…}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[11] broadcasted
@ ./broadcast.jl:1339 [inlined]
[12] position
@ ~/.julia/dev/AtomsBase/src/interface.jl:139 [inlined]
[13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{StaticArraysCore.SVector{…}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[14] dist
@ ./REPL[5]:2 [inlined]
[15] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[16] #1
@ ./REPL[6]:5 [inlined]
[17] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[18] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
[19] gradient(::Function, ::Int64, ::Vararg{Int64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
[20] top-level scope
@ REPL[6]:1
Some type information was truncated. Use `show(err)` to see complete types.
So far, I've looked at build_graph from https://github.com/Chemellia/AtomGraphs.jl/pull/11 as the test case for a code path to AD. I don't think Unitful is fully supported by AD.
I have always found it hard to get units to play well with AD, and don't use them when taking gradients in my own code.
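For what it's worth, here is a minimal sketch of that unit-free approach, assuming the positions are handled as plain Float64 vectors (implicitly in metres) and no AtomsBase system is constructed at all. The helper name dist_plain is made up for illustration and is not part of AtomsBase; the starting value 1.0 is chosen only so the gradient is non-zero.

```julia
using Zygote, LinearAlgebra

# Hypothetical helper (not part of AtomsBase): distance between the first
# two particles, given positions as plain Float64 vectors without units.
dist_plain(positions) = norm(positions[2] - positions[1])

# Differentiate with respect to the x-coordinate of the second particle.
g = gradient(1.0) do x
    positions = [[0.0, 0.0, 0.0], [x, 0.5, 0.5]]
    dist_plain(positions)
end
```

Units can be reattached to the result afterwards if needed. This only sidesteps the Unitful problem; it does not exercise position(sys) and so says nothing about the mutation error above.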
There is https://github.com/SBuercklin/UnitfulChainRules.jl which may be useful.
Zygote has a function called literal_indexed_iterate, which iterable types can overload to allow cleaner accumulation of gradients when working with AD. However, this adds a dependency on Zygote, which might be costly for a base package. Package extensions cannot be used either, since the extension would overwrite existing methods, which amounts to type piracy; this is also disallowed as of Julia 1.10. This PR is therefore a simple way to still benefit from AD-able code generation without introducing any extra complexity.