Closed. zsteve closed this issue 3 years ago
This is way too much code for a MWE, and the stacktrace seems reminiscent of some recent issues where Zygote wasn't being used at all (note the presence of the old Flux AD, Tracker). If you're not able to create a diffeq-less MWE, I'd recommend asking in their channels to see if anyone has seen this before (I believe they have) and how best to address it.
You are right, I should have tried to narrow the scope to exclude diffeq, sorry! I was under the impression I was making some trivial mistake since I'm new to Flux. Will close and re-post a cleaned issue once ready.
Hi,
I'm trying to set up a DiffEqFlux model where I have a custom layer that computes the gradient (w.r.t. the input variable) of a scalar-valued neural network, but I am encountering an error when I try to train the model (presumably because gradients w.r.t. the parameters are also needed).
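To make the setup concrete, here is a rough sketch of the derivatives involved. The symbols are assumptions introduced for illustration (they do not appear in the original post): $\Psi_\theta$ is the scalar-valued network with parameters $\theta$, $f_\theta$ is the output of the custom gradient layer, and $L$ is a training loss that, for simplicity, depends on $\theta$ only through $f_\theta$ evaluated at a single input $x$.

```latex
\begin{aligned}
\Psi_\theta &: \mathbb{R}^d \to \mathbb{R}
  && \text{(scalar-valued network)} \\
f_\theta(x) &= \nabla_x \Psi_\theta(x)
  && \text{(custom gradient layer, derivative w.r.t. the input)} \\
\frac{\partial L}{\partial \theta}
  &= \left( \frac{\partial f_\theta(x)}{\partial \theta} \right)^{\!\top}
     \frac{\partial L}{\partial f}
  && \text{(chain rule for training)} \\
\frac{\partial f_\theta(x)}{\partial \theta}
  &= \frac{\partial}{\partial \theta} \, \nabla_x \Psi_\theta(x)
  && \text{(mixed second derivative of } \Psi_\theta \text{)}
\end{aligned}
```

The last line is why training requires differentiating through the result of another differentiation: the forward pass only needs $\nabla_x \Psi_\theta$, but the backward pass needs its derivative with respect to $\theta$.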
Update: I just saw #1518 and it seems that Zygote has issues with nesting AD. However, I do not think it works to just take derivatives with respect to both inputs and parameters at the same time, because the gradient w.r.t. the input is fed into an SDE solver. Any help or pointers would be appreciated!
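One possible way to sidestep nested AD, sketched here under stated assumptions rather than as the approach taken in this issue, is to write the input gradient out by hand for a small network so that only a single level of Zygote differentiation (w.r.t. the parameters) is needed. The names `potential`, `potential_grad`, the NamedTuple `p`, and the one-hidden-layer `tanh` architecture are all hypothetical and chosen only for illustration; no DiffEqFlux or SDE solver is involved.

```julia
using Zygote

# Hypothetical one-hidden-layer scalar potential, summed over the batch:
# Ψ(x) = Σ_n w2' * tanh.(W1 * x_n .+ b1)
potential(p, x) = sum(p.w2' * tanh.(p.W1 * x .+ p.b1))

# Input gradient ∂Ψ/∂x written out analytically (no AD involved here).
function potential_grad(p, x)
    h = tanh.(p.W1 * x .+ p.b1)             # hidden activations, size (H, N)
    return p.W1' * (p.w2 .* (1 .- h .^ 2))  # size (D, N), same as x
end

# Assumed sizes: input dimension 1, hidden width 8, batch of 100 points.
p = (W1 = randn(8, 1), b1 = randn(8), w2 = randn(8))
x = randn(1, 100)

# Sanity check: the hand-written gradient agrees with Zygote's input gradient.
g_ad = Zygote.gradient(x -> potential(p, x), x)[1]
@assert isapprox(g_ad, potential_grad(p, x); rtol = 1e-6)

# Single-level AD w.r.t. the parameters, through the hand-written input gradient.
# The loss is a placeholder, not the training objective from the issue.
loss(p) = sum(abs2, potential_grad(p, x))
grads = Zygote.gradient(loss, p)[1]  # NamedTuple with fields W1, b1, w2
```

This only scales to architectures where the input gradient is easy to derive by hand, so it is more of a debugging aid for isolating the nested-AD problem than a general fix.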
The code I've got is below, where I'm using Flux.gradient to produce the custom gradient layer (I've also tried ForwardDiff, which throws a different error suggesting that ForwardDiff can't be used within Zygote). I am using FastChain as suggested here to allow the parameters from the potential to pass through to the drift.
Seems that the parameters are working out for the neuralsde object:
and sol = neuralsde(randn(1, 100)) also works, so sampling from the SDE is not a problem, only training. Suspect this has to do with Zygote trying to AD (w.r.t. network parameters) through the output of Flux.gradient (w.r.t. input variable).
The minimal training code I am using is
The error is as below.