FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Zygote.gradient returns `nothing` instead of `NotImplemented` #1227

Open simonmandlik opened 2 years ago

simonmandlik commented 2 years ago
using ChainRulesCore, Zygote

f(x) = 2x
# Custom rule: the pullback marks the tangent with respect to x as not implemented.
ChainRulesCore.rrule(::typeof(f), x) = 2x, d -> (NoTangent(), @not_implemented(""), )

Zygote.gradient(f, 1)

From v0.6.38 onwards the last line returns (nothing,), whereas earlier versions returned a ChainRulesCore.NotImplemented tangent.

Is this intended?

ToucheSir commented 2 years ago

Yes: https://github.com/FluxML/Zygote.jl/pull/1205. As an alternative, you could try using rrule_via_ad(ZygoteRuleConfig(), f, 1). That won't help if a NotImplemented is generated while AD is running, but it will short-circuit if it finds an rrule.
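For readers who want to try the rrule_via_ad route, here is a minimal sketch reusing the f and rrule from the original report. It assumes a Zygote version that provides Zygote.ZygoteRuleConfig and that rrule_via_ad returns the usual (value, pullback) pair; the variable names y, back, and dx are illustrative only.

```julia
using ChainRulesCore, Zygote

f(x) = 2x
# Same rule as in the report: the tangent for x is marked as not implemented.
ChainRulesCore.rrule(::typeof(f), x) = 2x, d -> (NoTangent(), @not_implemented(""))

# rrule_via_ad looks for a direct rrule first, so the ChainRules tangent types
# are returned without being converted to Zygote's `nothing` convention.
y, back = rrule_via_ad(Zygote.ZygoteRuleConfig(), f, 1)

# The pullback returns one tangent for f itself and one for x.
df, dx = back(1)

# Under the assumptions above, this is expected to be true.
dx isa ChainRulesCore.NotImplemented
```

As noted above, this only helps when the top-level call hits a rule directly. If the NotImplemented arises deeper inside a function that Zygote differentiates itself, it is still expected to be converted on the way out.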

If anyone has any ideas on how to refactor the internals such that rrule_via_ad doesn't have to go through an unwrapping/re-wrapping step (i.e. AD uses ChainRules types all the way through), let me know and I'd be happy to discuss + assist in implementation.

simonmandlik commented 2 years ago

Ok, thanks!

ToucheSir commented 1 year ago

Re-opening as a tracker for https://github.com/JuliaDiff/ChainRules.jl/pull/521#issuecomment-1445842030.