staticfloat closed this issue 1 year ago.
> but I don't know how to get the adjoint of a user-provided function
You could use `Zygote.pullback` to AD through it, which will pick up an adjoint if one exists. A PR to ChainRules would be well received; see https://github.com/JuliaDiff/ChainRules.jl/issues/85
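As a rough illustration of that suggestion, `Zygote.pullback` returns both the primal value and a closure that maps an output cotangent to input gradients, so it can supply the adjoint of an arbitrary user function. This is a minimal sketch; `abs2` stands in for the user-provided `f`, and the input values are arbitrary.

```julia
using Zygote

# Any differentiable user function can stand in for `f`; abs2 is used here.
f = abs2

# pullback returns the primal value and a closure mapping an output
# cotangent to a tuple of gradients, one per positional argument.
y, back = Zygote.pullback(f, 3.0)

# Seeding with 1.0 gives df/dx at x = 3.0, i.e. 2 * 3.0
(dx,) = back(1.0)

# The same works for an element-wise application of f over an array.
ys, back_arr = Zygote.pullback(x -> f.(x), [1.0, 2.0])
(dxs,) = back_arr(ones(2))
```

Here `dx` is `6.0` and `dxs` is `[2.0, 4.0]`, i.e. `2x` element-wise.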
The low-tech way to implement this is to turn it into broadcasting, as is currently done for `sum(::Function, ::CuArray)` here:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L280-L283
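Following that broadcasting approach, one possible shape for an adjoint covering `mean(f, xs; dims)` is sketched below. It is an assumption, not the fix that was merged (and later reverted): it rewrites `mean(f, xs; dims)` as `mean(f.(xs); dims)` and lets `Zygote.pullback` differentiate that form. Defining it outside Zygote is type piracy, so treat it as a local workaround only.

```julia
using Zygote, Statistics

# Sketch (assumption): differentiate mean(f, xs; dims) by rewriting it as a
# broadcast followed by mean(...; dims), which Zygote can already handle.
Zygote.@adjoint function mean(f, xs::AbstractArray; dims = :)
    # pullback over (f, xs) returns (value, back), where back yields one
    # gradient per positional argument, as @adjoint expects.
    return Zygote.pullback((f, xs) -> mean(f.(xs); dims = dims), f, xs)
end

x = [1.0 2.0; 3.0 4.0]

# The gradient of sum(mean(abs2, x; dims = 1)) is 2x / size(x, 1), which is x here.
g = Zygote.gradient(x -> sum(mean(abs2, x; dims = 1)), x)[1]
```

The inner call `mean(f.(xs); dims = dims)` takes a single array argument, so it does not dispatch back into this rule.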
Reopened, as I had to revert the fix.
If you provide both an element-wise function `f` and a dimension specification, `mean()` apparently causes array mutation, which breaks Zygote's ability to differentiate:

Looking through the adjoints for `mean()` defined in `lib/array.jl`, I would guess that the fact that I'm passing `abs2` in for `f` causes Zygote's implementation to be skipped altogether, and then the `dims` kwarg causes us to go down a bad path that involves array mutation. I was going to submit a PR to create a new `@adjoint` definition for one that includes `f`, but I don't know how to get the adjoint of a user-provided function.