Open jgreener64 opened 2 years ago
Can reproduce the errors shown. Note BTW that Molly only seems to work (for me) on 1.7, not on Julia 1.8. (But maybe it's an Apple M1 problem.)
Without the opt-out, `@debug` prints some things. And the shortcut in `rrule_via_ad` does not seem to be involved in calling these rules:
(jl_GvojyS) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GvojyS/Project.toml`
[082447d4] ChainRules v1.44.2
[aa0f7f06] Molly v0.13.0
[e88e6eb3] Zygote v0.6.43
julia> ENV["JULIA_DEBUG"] = ChainRules;
# This isn't in fact called:
julia> @eval Zygote function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
# Method definition evaluated inside the Zygote module from the REPL.
# NOTE(review): presumably Zygote's own `rrule_via_ad` with only the `@info "rrule shortcut"` line added for tracing — confirm against the Zygote source.
# first check whether there is an `rrule` which handles this directly
direct = rrule(config, f_args...; kwargs...)
# `f` is only used for the log message below.
f = f_args[1]
# If `rrule` returned something other than `nothing`, log the function and return that result as-is.
direct === nothing || (@info "rrule shortcut" f; return direct)
# create a closure to work around _pullback not accepting kwargs
# but avoid creating a closure unnecessarily (pullbacks of closures do not infer)
y, pb = if !isempty(kwargs)
# With keyword arguments: differentiate the zero-argument closure `kwf`, then read the
# `f_args` field out of the first element returned by its pullback.
kwf() = first(f_args)(Base.tail(f_args)...; kwargs...)
_y, _pb = _pullback(config.context, kwf)
_y, Δ -> first(_pb(Δ)).f_args # `first` should be `only`
else
# Without keyword arguments: call `_pullback` on `f_args` directly.
_pullback(config.context, f_args...)
end
# Returned pullback: passes Δ through `wrap_chainrules_output`, applies `pb`, and converts
# the result with `zygote2differential` using `f_args`.
ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
return y, ad_pullback
end;
# Note that the T == Bool path is called many times, no @info here
julia> @eval Zygote @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
# `@adjoint` rule for `broadcasted`, evaluated inside the Zygote module from the REPL.
# It only matches an `AbstractArrayStyle` first argument. The two `@info` calls report
# which non-Bool path was taken (forward-mode or per-element pullbacks).
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
# Bool output: return the broadcast result with a pullback that returns `nothing`.
# No `@info` is emitted on this path.
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
@info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward" f
return broadcast_forward(f, args...)
end
# `len` is the number of gradient slots built by `ntuple` in `∇broadcasted` below.
len = inclen(args)
@info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks" f
# Broadcast `_pullback` over the arguments, giving one (value, pullback) pair per element.
# NOTE(review): `__context__` is not defined in this snippet — presumably supplied by `@adjoint`; confirm.
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
# The primal result is the first entry of each pair.
y = map(first, y∂b)
function ∇broadcasted(ȳ)
# Apply each element's pullback to the matching element of `ȳ`.
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
# For each slot `i`, gather the i-th entry of every element's result and pass it through `collapse_nothings`.
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
# Returned tuple: `nothing`, then `accum_sum(dxs[1])`, then `unbroadcast` applied to
# each of `args` with the remaining entries of `dxs`.
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
return y, ∇broadcasted
end
# Random easy test
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└ f = #6 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)
# MWE from above
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: split broadcasting generic
│ f = inject_interaction_list (generic function with 4 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
(Error as above.)
With the opt-out, it's the second broadcast above, with `inject_interaction`, which fails:
julia> ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└ f = #18 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)
julia> gradient((xs, y) -> sum((x -> sin(x/y)).(xs)), [1,2,3]/4, 5/6)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = #22 (generic function with 1 method)
([1.146403786950727, 0.9904027378916139, 0.7459319619247974], -1.609501544552504)
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
@ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/lib.jl:326
So plausibly it's a failure of the opt-out mechanism? Where above it uses the `rrule`, here it instead seems to not find the `@adjoint` rule at all, perhaps?
Now I see. The CR rule accepts any `BroadcastStyle`, to handle tuples too, while the Zygote one restricts to `AbstractArrayStyle`. The cases where the CR rule is called all have `Broadcast.Style{Tuple}()`:
julia> @eval ChainRules function rrule(cfg::RCR, ::typeof(broadcasted), style::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N}
# `rrule` for `broadcasted`, evaluated inside the ChainRules module from the REPL.
# It accepts any `BroadcastStyle`. The first `@debug` call logs `f` and `style` on every
# invocation; the rest selects one of four paths from the combined element type `T`.
@debug "called the rrule!" f style
T = Broadcast.combine_eltypes(f, args)
if T === Bool # TODO use nondifftype here
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
@debug("split broadcasting trivial", f, T)
# Pullback ignores its input and returns `TRI_NO...` followed by one `ZeroTangent()` per argument.
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
return f.(args...), bc_trivial_back
elseif T <: Number && may_bc_derivatives(T, f, args...)
# 2: Fast path: use arguments & result to find derivatives.
return split_bc_derivatives(f, args...)
elseif T <: Number && may_bc_forwards(cfg, f, args...)
# 3: Future path: use `frule_via_ad`?
return split_bc_forwards(cfg, f, args...)
else
# 4: Slow path: collect all the pullbacks & apply them later.
return split_bc_pullbacks(cfg, f, args...)
end
end
rrule (generic function with 1065 methods)
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
┌ Debug: called the rrule!
│ f = inject_interaction (generic function with 7 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: called the rrule!
│ f = inject_interaction_list (generic function with 4 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction_list (generic function with 4 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_interaction (generic function with 7 methods)
┌ Debug: called the rrule!
│ f = inject_interaction (generic function with 7 methods)
│ style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│ f = inject_interaction (generic function with 7 methods)
│ N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
This is not solved by changing the `rrule`'s signature to match the `@adjoint` rule and reject `Broadcast.Style{Tuple}`. With https://github.com/JuliaDiff/ChainRules.jl/commit/6e383c10f8dfc0731be64a51bc3a33a6c7d21f5b you get the same error as with the `@opt_out`:
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└ f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
This is a Zygote bug. I wish I could transfer this issue there.
As discussed on Slack there is an issue with dictionaries appearing in gradients. The following is as minimum an example as I could make.
This requires Molly master, Zygote master and ChainRules 1.44.2; I am using Julia 1.7.2. The file `ala5.pdb` should be put in the current directory and is pasted below.
The `ala5.pdb` file:
On ChainRules up to 1.42.0 this worked; on 1.43.0-1.44.1 it errors with a different error, fixed by https://github.com/JuliaDiff/ChainRules.jl/pull/661; and on 1.44.2 it errors as follows:
Adding
ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}
to Molly, as suggested by @mcabbott, gives a different error:
Commenting out either the `"inter_LJ_weight_14" => 0.5,` line or the `"inter_CO_weight_14" => 0.5,` line makes it work, presumably because no dictionaries have to be added when there is only one gradient.