JuliaDiff / ChainRules.jl

Forward- and reverse-mode automatic differentiation primitives for Julia Base + StdLibs
Other
435 stars 89 forks source link

Dict gradients leading to addition error after broadcasting change #662

Open jgreener64 opened 2 years ago

jgreener64 commented 2 years ago

As discussed on Slack, there is an issue with dictionaries appearing in gradients. The following is as minimal an example as I could make.

This requires Molly master, Zygote master and ChainRules 1.44.2; I am using Julia 1.7.2. The file ala5.pdb should be placed in the current directory; its contents are pasted below.

using Molly, Zygote

# Locate Molly's bundled data directory relative to the installed package source.
data_dir = joinpath(dirname(pathof(Molly)), "..", "data")
# Force field read from the ff99SBildn.xml and his.xml files shipped with Molly.
# `units=false` (here and in `System` below) requests unitless quantities,
# consistent with the `NoUnits` force/energy units used in `loss`.
ff = OpenMMForceField(
    joinpath(data_dir, "force_fields", "ff99SBildn.xml"),
    joinpath(data_dir, "force_fields", "his.xml");
    units=false,
)

# System built from ala5.pdb (expected in the current directory) inside a
# 500 x 500 x 500 cubic boundary, with GBN2 implicit solvent.
# NOTE(review): gpu_diff_safe=true presumably selects Molly's AD-friendly code
# path — confirm against the Molly documentation.
sys = System(
    "ala5.pdb",
    ff;
    boundary=CubicBoundary(500.0, 500.0, 500.0),
    units=false,
    gpu_diff_safe=true,
    implicit_solvent="gbn2",
)

# Reference coordinates: `loss` restarts each simulation from these and measures
# RMSD against them. `copy` keeps them independent of `sys.coords`.
starting_coords = copy(sys.coords)
# Langevin integrator: time step 0.001, temperature 300.0, friction 1.0 (unitless).
sim = Langevin(dt=0.001, temperature=300.0, friction=1.0)

# Parameters the gradient is taken with respect to. Both entries are needed to
# trigger the failure; with only one of them, `gradient` succeeds.
params_dic = Dict(
    "inter_LJ_weight_14" => 0.5,
    "inter_CO_weight_14" => 0.5,
)

# Loss for the reproducer: rebuild a System whose interactions come from
# `inject_gradients(sys, params_dic)`, run 20 steps of `sim` from the starting
# coordinates, and return the RMSD of the final coordinates to `starting_coords`.
function loss(params_dic)
    # Per the reported stack trace, the failing pullback is that of a broadcast
    # inside this call (Molly src/gradients.jl), which sums Dict tangents.
    atoms, pis, sis, gis = inject_gradients(sys, params_dic)
    sys2 = System(
        atoms=atoms,
        pairwise_inters=pis,
        specific_inter_lists=sis,
        general_inters=gis,
        # Fresh copy so simulate! on sys2 does not overwrite starting_coords.
        coords=copy(starting_coords),
        boundary=sys.boundary,
        neighbor_finder=sys.neighbor_finder,
        force_units=NoUnits,
        energy_units=NoUnits,
        gpu_diff_safe=true,
    )
    simulate!(sys2, sim, 20)
    return rmsd(sys2.coords, starting_coords)
end

# Forward evaluation; the reported error comes from the `gradient` call below.
loss(params_dic)

# On ChainRules 1.44.2 this throws
# MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any}).
gradient(loss, params_dic)

The ala5.pdb file:

REMARK   1 CREATED WITH OPENMM 7.7, 2022-03-17
ATOM      1  N   ALA A   1      -0.677  -1.230  -0.491  1.00  0.00           N  
ATOM      2  H1  ALA A   1      -1.672  -1.326   0.175  1.00  0.00           H  
ATOM      3  H2  ALA A   1      -0.205  -2.312  -0.284  1.00  0.00           H  
ATOM      4  H3  ALA A   1      -1.142  -1.396  -1.586  1.00  0.00           H  
ATOM      5  CA  ALA A   1      -0.001   0.064  -0.491  1.00  0.00           C  
ATOM      6  HA  ALA A   1      -0.307   0.761  -1.410  1.00  0.00           H  
ATOM      7  C   ALA A   1       1.499  -0.110  -0.491  1.00  0.00           C  
ATOM      8  O   ALA A   1       2.233   0.524  -1.257  1.00  0.00           O  
ATOM      9  CB  ALA A   1      -0.509   0.856   0.727  1.00  0.00           C  
ATOM     10  HB1 ALA A   1      -1.630   1.260   0.586  1.00  0.00           H  
ATOM     11  HB2 ALA A   1       0.147   1.855   0.821  1.00  0.00           H  
ATOM     12  HB3 ALA A   1      -0.513   0.440   1.850  1.00  0.00           H  
ATOM     13  N   ALA A   2       2.031  -0.947   0.335  1.00  0.00           N  
ATOM     14  H   ALA A   2       1.491  -1.234   1.355  1.00  0.00           H  
ATOM     15  CA  ALA A   2       3.481  -1.115   0.335  1.00  0.00           C  
ATOM     16  HA  ALA A   2       3.979  -0.110   0.741  1.00  0.00           H  
ATOM     17  C   ALA A   2       3.979  -1.516  -1.034  1.00  0.00           C  
ATOM     18  O   ALA A   2       4.951  -0.967  -1.565  1.00  0.00           O  
ATOM     19  CB  ALA A   2       3.832  -2.145   1.422  1.00  0.00           C  
ATOM     20  HB1 ALA A   2       3.242  -2.174   2.466  1.00  0.00           H  
ATOM     21  HB2 ALA A   2       3.903  -3.307   1.138  1.00  0.00           H  
ATOM     22  HB3 ALA A   2       4.951  -1.909   1.785  1.00  0.00           H  
ATOM     23  N   ALA A   3       3.371  -2.461  -1.667  1.00  0.00           N  
ATOM     24  H   ALA A   3       2.703  -3.280  -1.122  1.00  0.00           H  
ATOM     25  CA  ALA A   3       3.852  -2.848  -2.990  1.00  0.00           C  
ATOM     26  HA  ALA A   3       4.957  -3.295  -2.907  1.00  0.00           H  
ATOM     27  C   ALA A   3       3.863  -1.666  -3.929  1.00  0.00           C  
ATOM     28  O   ALA A   3       4.836  -1.407  -4.647  1.00  0.00           O  
ATOM     29  CB  ALA A   3       2.962  -3.999  -3.492  1.00  0.00           C  
ATOM     30  HB1 ALA A   3       3.402  -4.309  -4.563  1.00  0.00           H  
ATOM     31  HB2 ALA A   3       1.783  -3.954  -3.696  1.00  0.00           H  
ATOM     32  HB3 ALA A   3       3.081  -5.011  -2.859  1.00  0.00           H  
ATOM     33  N   ALA A   4       2.825  -0.902  -3.984  1.00  0.00           N  
ATOM     34  H   ALA A   4       1.758  -1.411  -3.866  1.00  0.00           H  
ATOM     35  CA  ALA A   4       2.836   0.242  -4.892  1.00  0.00           C  
ATOM     36  HA  ALA A   4       2.957  -0.134  -6.020  1.00  0.00           H  
ATOM     37  C   ALA A   4       4.002   1.154  -4.597  1.00  0.00           C  
ATOM     38  O   ALA A   4       4.737   1.590  -5.492  1.00  0.00           O  
ATOM     39  CB  ALA A   4       1.477   0.951  -4.765  1.00  0.00           C  
ATOM     40  HB1 ALA A   4       1.191   1.837  -4.012  1.00  0.00           H  
ATOM     41  HB2 ALA A   4       1.384   1.543  -5.807  1.00  0.00           H  
ATOM     42  HB3 ALA A   4       0.476   0.293  -4.808  1.00  0.00           H  
ATOM     43  N   ALA A   5       4.239   1.491  -3.374  1.00  0.00           N  
ATOM     44  H   ALA A   5       3.314   2.039  -2.863  1.00  0.00           H  
ATOM     45  CA  ALA A   5       5.366   2.373  -3.090  1.00  0.00           C  
ATOM     46  HA  ALA A   5       5.290   3.438  -3.625  1.00  0.00           H  
ATOM     47  C   ALA A   5       6.657   1.787  -3.611  1.00  0.00           C  
ATOM     48  O   ALA A   5       6.692   0.703  -4.205  1.00  0.00           O  
ATOM     49  CB  ALA A   5       5.395   2.620  -1.571  1.00  0.00           C  
ATOM     50  HB1 ALA A   5       4.470   3.196  -1.074  1.00  0.00           H  
ATOM     51  HB2 ALA A   5       6.276   3.405  -1.348  1.00  0.00           H  
ATOM     52  HB3 ALA A   5       5.721   1.705  -0.874  1.00  0.00           H  
ATOM     53  OXT ALA A   5       7.648   2.530  -3.396  1.00  0.00           O  
TER      54      ALA A   5
END

On ChainRules up to 1.42.0 this worked. On 1.43.0–1.44.1 it fails with a different error, fixed by https://github.com/JuliaDiff/ChainRules.jl/pull/661. On 1.44.2 it fails as follows:

ERROR: LoadError: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at ~/soft/julia/julia-1.7.2/share/julia/base/operators.jl:655
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at ~/.julia/packages/InitialValues/OWP8V/src/InitialValues.jl:154
  +(::Dict, ::ChainRulesCore.Tangent{P}) where P at ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:145
  ...
Stacktrace:
  [1] add_sum(x::Dict{Any, Any}, y::Dict{Any, Any})
    @ Base ./reduce.jl:24
  [2] _mapreduce
    @ ./reduce.jl:410 [inlined]
  [3] _mapreduce_dim
    @ ./reducedim.jl:330 [inlined]
  [4] #mapreduce#725
    @ ./reducedim.jl:322 [inlined]
  [5] mapreduce
    @ ./reducedim.jl:322 [inlined]
  [6] #_sum#735
    @ ./reducedim.jl:894 [inlined]
  [7] _sum
    @ ./reducedim.jl:894 [inlined]
  [8] #_sum#734
    @ ./reducedim.jl:893 [inlined]
  [9] _sum
    @ ./reducedim.jl:893 [inlined]
 [10] #sum#732
    @ ./reducedim.jl:889 [inlined]
 [11] sum
    @ ./reducedim.jl:889 [inlined]
 [12] unbroadcast
    @ ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:348 [inlined]
 [13] map(f::typeof(ChainRules.unbroadcast), t::Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}, s::Tuple{Vector{ChainRulesCore.Tangent}, Vector{Dict{Any, Any}}})
    @ Base ./tuple.jl:247
 [14] (::ChainRules.var"#back_generic#1708"{typeof(Molly.inject_interaction), Tuple{Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}, Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}}, Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}})(dys::ChainRulesCore.Tangent{Any, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}})
    @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:134
 [15] ZBack
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:206 [inlined]
 [16] (::Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.ZBack{ChainRules.var"#back_generic#1708"{typeof(Molly.inject_interaction), Tuple{Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}, Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}}, Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}}}})(Δ::Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/lib.jl:206
 [17] (::Zygote.var"#1914#back#210"{Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.ZBack{ChainRules.var"#back_generic#1708"{typeof(Molly.inject_interaction), Tuple{Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}, Zygote.var"#ad_pullback#50"{Tuple{typeof(Molly.inject_interaction), Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Dict{String, Float64}}, typeof(∂(inject_interaction))}}, Tuple{Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{Dict{String, Float64}}}}}}})(Δ::Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [18] Pullback
    @ ./broadcast.jl:1303 [inlined]
 [19] (::typeof(∂(broadcasted)))(Δ::Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [20] Pullback
    @ ~/.julia/dev/Molly/src/gradients.jl:98 [inlined]
 [21] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [22] Pullback
    @ ~/.julia/dev/Molly/src/gradients.jl:92 [inlined]
 [23] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [24] Pullback
    @ ~/dms/molly_dev/grad_err.jl:31 [inlined]
 [25] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#60#61"{typeof(∂(loss))})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
 [27] gradient(f::Function, args::Dict{String, Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [28] top-level scope
    @ ~/dms/molly_dev/grad_err.jl:50

Adding `ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}` to Molly, as suggested by @mcabbott, gives a different error:

ERROR: LoadError: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/lib.jl:327
  [3] (::Zygote.var"#1948#back#224"{Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false}})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:170 [inlined]
  [5] (::typeof(∂(Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}})))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./broadcast.jl:179 [inlined]
  [7] (::typeof(∂(Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}})))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./broadcast.jl:179 [inlined]
  [9] (::typeof(∂(Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}})))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [10] Pullback
    @ ./broadcast.jl:1305 [inlined]
 [11] (::Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof(∂(broadcasted))})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/lib.jl:206
 [12] (::Zygote.var"#1914#back#210"{Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, typeof(∂(broadcasted))}})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [13] Pullback
    @ ./broadcast.jl:1303 [inlined]
 [14] (::typeof(∂(broadcasted)))(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/dev/Molly/src/gradients.jl:108 [inlined]
 [16] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/dev/Molly/src/gradients.jl:92 [inlined]
 [18] (::typeof(∂(inject_gradients)))(Δ::Tuple{Vector{Atom{Float64, Float64, Float64, Float64}}, Tuple{LennardJones{false, Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{Nothing, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{NamedTuple{(:is, :js, :types, :inters), Tuple{Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Vector{Tuple{Float64, Float64}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}, NamedTuple{(:is, :js, :ks, :ls, :types, :inters), Tuple{Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Tuple{NTuple{6, Int64}, NTuple{6, Float64}, NTuple{6, Float64}, Bool}}}}}, Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [19] Pullback
    @ ~/dms/molly_dev/grad_err.jl:31 [inlined]
 [20] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#60#61"{typeof(∂(loss))})(Δ::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:45
 [22] gradient(f::Function, args::Dict{String, Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:97
 [23] top-level scope
    @ ~/dms/molly_dev/grad_err.jl:50

Commenting out either the `"inter_LJ_weight_14" => 0.5,` line or the `"inter_CO_weight_14" => 0.5,` line makes it work, presumably because with only one gradient entry no dictionaries have to be added together.

mcabbott commented 2 years ago

I can reproduce the errors shown. Note, by the way, that Molly only seems to work (for me) on Julia 1.7, not on Julia 1.8 (but maybe that's an Apple M1 problem).

Without the opt-out, the `@debug` statements print some output. The shortcut in `rrule_via_ad` also does not seem to be involved in calling these rules:

(jl_GvojyS) pkg> st
      Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_GvojyS/Project.toml`
  [082447d4] ChainRules v1.44.2
  [aa0f7f06] Molly v0.13.0
  [e88e6eb3] Zygote v0.6.43

julia> ENV["JULIA_DEBUG"] = ChainRules;

# This isn't in fact called:
julia> @eval Zygote function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
           # first check whether there is an `rrule` which handles this directly
           direct = rrule(config, f_args...; kwargs...)
           f = f_args[1]
           direct === nothing || (@info "rrule shortcut" f; return direct)

           # create a closure to work around _pullback not accepting kwargs
           # but avoid creating a closure unnecessarily (pullbacks of closures do not infer)
           y, pb = if !isempty(kwargs)
               kwf() = first(f_args)(Base.tail(f_args)...; kwargs...)
               _y, _pb = _pullback(config.context, kwf)
               _y, Δ -> first(_pb(Δ)).f_args  # `first` should be `only`
           else
               _pullback(config.context, f_args...)
           end

           ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
           return y, ad_pullback
       end;

# Note that the T == Bool path is called many times, no @info here
julia> @eval Zygote @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
         T = Broadcast.combine_eltypes(f, args)
         # Avoid generic broadcasting in two easy cases:
         if T == Bool
           return (f.(args...), _ -> nothing)
         elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
           @info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward" f
           return broadcast_forward(f, args...)
         end
         len = inclen(args)
         @info "Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks" f
         y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
         y = map(first, y∂b)
         function ∇broadcasted(ȳ)
           dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
           dxs = ntuple(len) do i
             collapse_nothings(map(StaticGetter{i}(), dxs_zip))
           end
           (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
         end
         return y, ∇broadcasted
       end

# Random easy test
julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└   f = #6 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)

# MWE from above
julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_atom (generic function with 1 method)
┌ Debug: split broadcasting generic
│   f = inject_interaction (generic function with 7 methods)
│   N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: split broadcasting generic
│   f = inject_interaction_list (generic function with 4 methods)
│   N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Debug: split broadcasting generic
│   f = inject_interaction (generic function with 7 methods)
│   N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})

(Error as above.)

With the opt-out, it is the second broadcast above (the one over `inject_interaction`) that fails:

julia> ChainRulesCore.@opt_out rrule(cfg::Zygote.ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), ::Broadcast.BroadcastStyle, f::F, args::Vararg{Any, N}) where {F, N}

julia> gradient(xs -> sum((x -> sin(x)).(xs)), [1,2,3]/4)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> broadcast_forward
└   f = #18 (generic function with 1 method)
([0.9689124217106447, 0.8775825618903728, 0.7316888688738209],)

julia> gradient((xs, y) -> sum((x -> sin(x/y)).(xs)), [1,2,3]/4, 5/6)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = #22 (generic function with 1 method)
([1.146403786950727, 0.9904027378916139, 0.7459319619247974], -1.609501544552504)

julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/lib.jl:326

So perhaps this is a failure of the opt-out mechanism? Whereas above the rrule is used, here Zygote instead seems not to find the `@adjoint` rule at all.

mcabbott commented 2 years ago

Now I see. The ChainRules (CR) rule accepts any `BroadcastStyle`, so that it also handles tuples, while the Zygote one is restricted to `AbstractArrayStyle`. The cases where the CR rule is called all have `Broadcast.Style{Tuple}()`:

julia> @eval ChainRules function rrule(cfg::RCR, ::typeof(broadcasted), style::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N}
       @debug "called the rrule!" f style
           T = Broadcast.combine_eltypes(f, args)
           if T === Bool  # TODO use nondifftype here
               # 1: Trivial case: non-differentiable output, e.g. `x .> 0`
               @debug("split broadcasting trivial", f, T)
               bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
               return f.(args...), bc_trivial_back
           elseif T <: Number && may_bc_derivatives(T, f, args...)
               # 2: Fast path: use arguments & result to find derivatives.
               return split_bc_derivatives(f, args...)
           elseif T <: Number && may_bc_forwards(cfg, f, args...)
               # 3: Future path: use `frule_via_ad`?
               return split_bc_forwards(cfg, f, args...)
           else
               # 4: Slow path: collect all the pullbacks & apply them later.
               return split_bc_pullbacks(cfg, f, args...)
           end
       end
rrule (generic function with 1065 methods)

julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_atom (generic function with 1 method)
┌ Debug: called the rrule!
│   f = inject_interaction (generic function with 7 methods)
│   style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│   f = inject_interaction (generic function with 7 methods)
│   N = 2
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Debug: called the rrule!
│   f = inject_interaction_list (generic function with 4 methods)
│   style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│   f = inject_interaction_list (generic function with 4 methods)
│   N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_interaction (generic function with 7 methods)
┌ Debug: called the rrule!
│   f = inject_interaction (generic function with 7 methods)
│   style = Base.Broadcast.Style{Tuple}()
└ @ ChainRules REPL[37]:2
┌ Debug: split broadcasting generic
│   f = inject_interaction (generic function with 7 methods)
│   N = 3
└ @ ChainRules ~/.julia/packages/ChainRules/DUopG/src/rulesets/Base/broadcast.jl:126
ERROR: MethodError: no method matching +(::Dict{Any, Any}, ::Dict{Any, Any})
mcabbott commented 2 years ago

This is not solved by changing the rrule's signature to match the `@adjoint` rule, and so reject `Broadcast.Style{Tuple}`. With https://github.com/JuliaDiff/ChainRules.jl/commit/6e383c10f8dfc0731be64a51bc3a33a6c7d21f5b, you get the same error as with the `@opt_out`:

julia> @time gradient(loss, params_dic)
┌ Info: Zygote's @adjoint broadcasted(::AbstractArrayStyle, ... -> pullbacks
└   f = inject_atom (generic function with 1 method)
ERROR: Need an adjoint for constructor Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}. Gradient is of type Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.Jnew{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(Molly.inject_interaction), Tuple{Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Tuple{Dict{String, Float64}}, Tuple{System{3, true, Float64, false, Vector{Atom{Float64, Float64, Float64, Float64}}, Vector{AtomData}, Tuple{LennardJones{false, DistanceCutoff{Float64, Float64, Float64}, Float64, Int64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}, Coulomb{DistanceCutoff{Float64, Float64, Float64}, Float64, Float64, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}}}, Tuple{InteractionList2Atoms{Vector{HarmonicBond{Float64, Float64}}}, InteractionList3Atoms{Vector{HarmonicAngle{Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}, InteractionList4Atoms{Vector{PeriodicTorsion{6, Float64, Float64}}}}, Tuple{ImplicitSolventGBN2{Float64, Float64, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Matrix{Int64}, Matrix{Float64}, Matrix{Float64}}}, Vector{SVector{3, Float64}}, Vector{SVector{3, Float64}}, CubicBoundary{Float64}, DistanceVecNeighborFinder{Float64, BitMatrix, Matrix{Int64}}, Tuple{}, Unitful.FreeUnits{(), NoDims, nothing}, Unitful.FreeUnits{(), NoDims, nothing}, Float64}}}}, Nothing, false})(Δ::Tuple{NamedTuple{(:offset_radii, :scaled_offset_radii, :solvent_dielectric, :solute_dielectric, :kappa, :offset, :dist_cutoff, :use_ACE, :αs, :βs, :γs, :probe_radius, :sa_factor, :factor_solute, :factor_solvent, :is, :js, :d0s, :m0s, :neck_scale, :neck_cut), Tuple{Vector{Float64}, Vector{Float64}, Nothing, Nothing, Float64, Float64, Nothing, Nothing, Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64, Float64, Float64, Float64, Nothing, Nothing, Matrix{Float64}, Matrix{Float64}, Float64, Float64}}})
oxinabox commented 2 years ago

This is a Zygote bug. I wish I could transfer this issue there.