Closed samuela closed 4 years ago
`ArrayPartition`s as `u0` work, and that's tested in https://github.com/SciML/DiffEqSensitivity.jl/blob/master/test/local_sensitivity/second_order_odes.jl since second order ODEs solve an `ODEProblem` defined with a `u0` that is an `ArrayPartition`. So this issue is specific to `ArrayPartition{ArrayPartition}`, and it's specifically a Zygote issue. ReverseDiff.jl doesn't support custom array types at all, so it isn't really the issue here beyond having a hard limitation that means it won't work 🤷, so we'll stick to Zygote.
And it's a bug because Zygote's type handling for abstract array types is incorrect. Essentially, when it sees `AbstractArray`s, it thinks it has license to change those to `Array` or `CuArray`, which it does implicitly inside of its `adapt` calls, but it has no such license and it shouldn't do this. How is `ArrayPartition` supported then? Well, because it's common enough in DiffEq, I specifically worked around this bug here:
However, given that there are infinitely many types, I don't think special-casing `ArrayPartition{ArrayPartition}` is going to be viable, so I'll loop in @DhairyaLGandhi, since we need a better solution to this. This is a very good MWE of what I'd say is a direct consequence of incorrect decisions made when defining the interface in https://fluxml.ai/Zygote.jl/dev/adjoints/#Custom-Types-1 . Specifically, `@adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),)` is not a good or composable idea because there are cases where you cannot just replace the other values in a type with 0: the type might require the field to be non-numeric, or, as here, the field might need to be sized in a way that you cannot easily define in the adjoint. Because of this sizing issue, there does not exist a completely usable definition for `ArrayPartition`'s adjoints in the Zygote system, hence this hack to make this case work. Another case where this shows up is dual numbers: they should support multi-partial duals, but exactly the same issue happens, and you cannot correctly define adjoints of multi-partial duals in the current system. This is the reason why the current support:
https://github.com/SciML/DiffEqFlux.jl/blob/master/src/DiffEqFlux.jl#L57-L68
only handles single-partial dual numbers (and asserts this). And then, because the pullback wants to give a `Dual` number (which is the root of all of the problems IMO), the derivative of a function that builds and uses a dual will be a dual number. Or, to put it another way: the derivative of a real-input function with a real output which has an intermediate portion that is a dual number will give the gradient as a dual number (???), violating all of the standard Zygote usage assumptions. You then have to work around it:
https://github.com/SciML/DiffEqFlux.jl/blob/master/src/DiffEqFlux.jl#L57-L68
This is what was mentioned in https://github.com/FluxML/Zygote.jl/pull/510#issuecomment-593116576 .
What these examples are really demonstrating is that not all array or number types can actually be supported in the current interface for defining adjoints of constructors, so IMO this is probably the main thing to fix in Zygote, since it is the root of a lot of compatibility issues. Since the fix would necessarily be breaking (it would break existing adjoint definitions), it may be good to line up a change here with @keno's major AD change. We should discuss this on one of the next community calls. Dual numbers are a good example of a number type that can't really be handled here, and `ArrayPartition` is an example of an array type, where exactly this problem of nesting `ArrayPartition{ArrayPartition}` comes in because you cannot necessarily define `zero` just from type information (which is true for a lot of array and number types!).
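To make the sizing point concrete, here is a minimal sketch in plain Julia with no AD package involved. `Partition`, `zero_like`, and `first_block_pullback` are hypothetical names made up for illustration; `Partition` is only a stand-in for `ArrayPartition`, and none of this is Zygote's or RecursiveArrayTools' actual API. It shows a hand-written pullback for "take the first block" that has to build correctly sized zeros for the remaining blocks, which is only possible because it closes over the primal value rather than working from the type.

```julia
# Hypothetical stand-in for a partitioned array (NOT RecursiveArrayTools.ArrayPartition).
struct Partition{T<:Tuple}
    parts::T
end

# A zero built from a *value*: sizes are known, and nesting recurses naturally.
zero_like(x::AbstractArray) = zero(x)
zero_like(p::Partition) = Partition(map(zero_like, p.parts))

# Hand-written pullback for "take the first block". The cotangent for the
# remaining blocks must be zeros of matching size, which only the primal
# value `p` can supply; the type Vector{Float64} alone carries no length.
function first_block_pullback(p::Partition)
    y = p.parts[1]
    back(dy) = Partition((dy, map(zero_like, Base.tail(p.parts))...))
    return y, back
end

# Nested case, analogous to ArrayPartition{ArrayPartition}.
inner = Partition(([1.0, 2.0], [3.0]))
outer = Partition((inner, [4.0, 5.0, 6.0]))

y, back = first_block_pullback(outer)
g = back(Partition(([1.0, 1.0], [1.0])))
```

Here `g` has the same nested shape as `outer`, with the seed in the first block and a length-3 zero vector in the second. An adjoint that only knows the type, in the style of `Point(x̄, 0)`, has nothing from which to recover that length.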
Potentially related to https://github.com/SciML/DifferentialEquations.jl/issues/679, but I think it's its own thing. Basically, `ArrayPartition`s, `DynamicalODEProblem`s, and DiffEqSensitivity.jl are an explosive combination. I've come across two types of errors so far, listed in the `dyn_x` function.
gives me