JuliaDiff / ChainRules.jl

Forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

Add rule for `mean(f, x; dims)` #85

Open · nickrobinson251 opened this issue 5 years ago

nickrobinson251 commented 5 years ago

We have rules for `mean`, except for `mean(f, x; dims)`, which is new as of Julia v1.3.
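For reference, here is the math such a rule would have to encode. This is a sketch assuming `f` is scalar and differentiable elementwise. Here $S_J$ is the set of indices of $x$ that collapse onto output index $J$, and $n$ is the number of elements averaged into each output:

$$
y_J = \frac{1}{n} \sum_{i \in S_J} f(x_i), \qquad n = \prod_{d \in \mathrm{dims}} \operatorname{size}(x, d), \qquad \bar{x}_i = \frac{\bar{y}_{J(i)}\, f'(x_i)}{n}
$$

In other words, the cotangent is the one for `sum(f, x; dims)` multiplied by $1/n$. With `dims = :`, $n$ is simply `length(x)`.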

mzgubic commented 2 years ago

The implementation could look a bit like the existing rule for `sum(f, xs)`:

https://github.com/JuliaDiff/ChainRules.jl/blob/ce78d3d3e8aaf6303e1aa7085fdbdfc2d36d1b64/src/rulesets/Base/mapreduce.jl#L69-L94

mcabbott commented 2 years ago

Ideally it would share code with that rule, e.g. via a function that, for `mean`, gets `scale = 1/size(...)` or something similar.

Xref #529, which is trying to rework that rule.
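Below is a minimal sketch of the shared-scale idea. It assumes Julia ≥ 1.3 with the Statistics stdlib and a closed-form derivative `df` of `f`. The names `mean_scale` and `mean_pullback_sketch` are hypothetical and are not ChainRules.jl API. A real rule would obtain the derivative of `f` through the AD system (e.g. ChainRulesCore's `rrule_via_ad`) rather than take `df` as an argument.

```julia
using Statistics

# Hypothetical helper (not part of ChainRules.jl): the factor that turns
# sum(f, x; dims) into mean(f, x; dims), i.e. 1 / (number of averaged elements).
mean_scale(x::AbstractArray, ::Colon) = 1 / length(x)
mean_scale(x::AbstractArray, dims) = 1 / prod(size(x, d) for d in dims)

# Hypothetical reverse-mode sketch for mean(f, x; dims), assuming the
# derivative `df` of `f` is known in closed form.
function mean_pullback_sketch(f, df, x::AbstractArray; dims=:)
    y = mean(f, x; dims=dims)
    scale = mean_scale(x, dims)
    function pullback(ȳ)
        # Broadcasting ȳ (reduced shape, or a scalar for dims=:) against x
        # spreads each output cotangent over the slice averaged into it;
        # this is the sum(f, x; dims) pullback times `scale`.
        return ȳ .* df.(x) .* scale
    end
    return y, pullback
end
```

Keeping the `1/n` factor in its own helper is what would let `mean` reuse the `sum(f, xs)` pullback unchanged. Only the scale differs between the two rules.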

oxinabox commented 2 years ago

Reopened, as I had to revert the fix.