dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Generic broadcast #102

Closed: dfdx closed this issue 2 years ago

dfdx commented 2 years ago

As seen in the HQDL report, functions that don't have an rrule(broadcasted, f, args...) currently fail in Yota. I made an attempt to implement a generic rrule(), but it was somewhat controversial, so the decision was postponed.

At the moment I see two ways to handle this:

  1. Implement generic rrule(::YotaRuleConfig, broadcasted, f, args...), i.e. restricted to YotaRuleConfig.
  2. Add special handling for broadcasted in step_back!().
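Option 1 can be illustrated roughly as follows. This is a toy, package-free sketch under stated assumptions, not Yota's code: scalar_rrule and broadcast_rrule are hypothetical names standing in for a per-element rrule and for the generic rrule(broadcasted, f, args...). In a ChainRulesCore-based setting, the per-element rule would presumably come from rrule_via_ad on the rule config rather than from a hand-written table. The sketch also assumes all arguments are arrays of the same shape, so it skips the dimension expansion and gradient summation that real broadcasting needs.

```julia
# Hypothetical per-element rules: each returns (y, pullback), mimicking the
# (value, pullback) shape of a ChainRules-style rrule. Not Yota's API.
scalar_rrule(::typeof(sin), x::Real) = (sin(x), dy -> (cos(x) * dy,))
scalar_rrule(::typeof(*), a::Real, b::Real) = (a * b, dy -> (b * dy, a * dy))

# Hypothetical generic rule for f.(args...): run the per-element rule on every
# element, keep the element pullbacks, and apply them in the backward pass.
function broadcast_rrule(f, args::AbstractArray...)
    # Assumption of this sketch: all arguments have the same size.
    for a in args
        size(a) == size(first(args)) ||
            throw(DimensionMismatch("this sketch assumes equally sized arguments"))
    end
    results = map((xs...) -> scalar_rrule(f, xs...), args...)
    y = map(first, results)          # forward values
    pullbacks = map(last, results)   # one pullback closure per element
    function broadcast_pullback(dy::AbstractArray)
        # Each element pullback returns a tuple with one cotangent per argument.
        dxs = map((pb, d) -> pb(d), pullbacks, dy)
        # Unzip into one cotangent array per argument.
        return ntuple(i -> map(t -> t[i], dxs), length(args))
    end
    return y, broadcast_pullback
end
```

For example, broadcast_rrule(sin, x) returns sin.(x) together with a pullback that maps an all-ones cotangent to cos.(x). The point of the generic version is that the broadcast rule is written once and works for any f that has a per-element rule.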

dfdx commented 2 years ago

Implemented via improved tracing. Not very robust, but mostly works :)