jkshtj opened 1 year ago
@BradLarson @asl Could you guys take a look at this?
It would be helpful if the testcase could be self-contained.
But judging from the rough sketch presented here: I won't bother about the top-level (`SM`) VJP and pullbacks being inlined or not. We actually do not need this. What is important are the internals of `SM`:

- Is `SM`'s VJP (the pullback) properly specialized?
- And if not, what prevents the optimization of `SM_VJP` itself?

In any case, it seems you're looking at this from the wrong direction. We do not want `SM` to be inlined in the first place; only if it's beneficial / profitable. For now it does not look profitable, as the function itself is huge. What we need is to optimize `SM` by itself (certainly, it may turn out to be profitable afterwards, but not vice versa).

In particular, we can think about the following optimization: consider a `partial_apply` whose argument is another `partial_apply`. Then we can specialize the callee so that the inner `partial_apply` is eliminated; the outer one will capture the arguments of the inner one. This way, instead of lots of small context memory allocations / releases, we'd have a single larger one. Essentially, this would look as if the autodiff code had been run on the function with the nested function calls inlined.
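To make the intended transformation concrete, here is a hedged Swift-level analogy (the real optimization rewrites SIL `partial_apply` instructions, and all names below are invented for illustration): a pullback that captures another pullback closure needs two closure-context allocations, while a version specialized over the inner closure's captured operands needs only one.

```swift
// Invented illustration; the actual optimization operates on SIL, not Swift.

// Inner pullback: one closure-context allocation capturing `scale`.
func makeInnerPB(scale: Float) -> (Float) -> Float {
    return { v in v * scale }
}

// Un-specialized outer pullback: a second context allocation that
// captures the inner *closure* itself.
func makeOuterPB(innerPB: @escaping (Float) -> Float) -> (Float) -> Float {
    return { v in innerPB(v) + 1 }
}

// Specialized outer pullback: captures `scale` directly, so the inner
// closure (and its separate context allocation) can be eliminated.
func makeOuterPBSpecialized(scale: Float) -> (Float) -> Float {
    return { v in v * scale + 1 }
}

let unspecialized = makeOuterPB(innerPB: makeInnerPB(scale: 2))
let specialized = makeOuterPBSpecialized(scale: 2)
print(unspecialized(3) == specialized(3))  // true
```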
Based on offline discussion with @asl we have the following work items:

1. Extend the existing closure-specialization optimization to work not just on closures that are called explicitly through `apply`s, but also on closures that are returned from other functions.
2. Modify (if need be) the inlining cost-benefit analysis to award more benefit to VJPs whose callers are other VJPs.
@asl here's the full code for the benchmark we have been using internally.
import _Differentiation
import Foundation
let dt: Float = 0.1
let π = Float.pi
struct SP: Differentiable {
var t: TT = TT()
var s: ST = ST()
var q: QT = QT()
var tn: TnT = TnT()
var stt: Float
}
struct TT: Differentiable
{
//@noDerivative
var sp: Float = 0.50292
//@noDerivative
var dm: Float = 0.019
//@noDerivative
var tk: Float = 0.001588
//@noDerivative
var rt: Float = 2.43
}
struct ST: Differentiable
{
var tp: Float = 21.1111111
//@noDerivative
var ar: Float = 100.0
//@noDerivative
var Cp: Float = 0.2
//@noDerivative
var dt: Float = 2242.58
//@noDerivative
var tk: Float = 0.101
var nuuu: Float = 10.0
var nuuuu: Float = 10.0
}
struct QT: Differentiable
{
var pw: Float = 0.0
var tp: Float = 60.0
var fw: Float = 0.0006309
var dt: Float = 1000.0
var Cp: Float = 4180.0
var nuuu: Float = 10.0
var nuuuu: Float = 10.0
var nuuuuu: Float = 10.0
}
struct TnT: Differentiable
{
var tp: Float = 70.0
//@noDerivative
var vl: Float = 0.0757082
//@noDerivative
var Cp: Float = 4180.000
//@noDerivative
var dt: Float = 1000.000
//@noDerivative
var ms: Float = 75.708
var nuuu: Float = 10.0
var nuuuu: Float = 10.0
var nuuuuu: Float = 10.0
var nuuuuuu: Float = 10.0
}
////////////// Computations ///////////////////////////////////////////////////////
@differentiable(reverse)
func CRT(fr: ST, tu: TT, qt: QT) -> Float
{
let gc: Float = 10.0
let ts = (fr.ar/tu.sp) * π * tu.dm
let rb = tu.rt * tu.tk / ts
let rc = rb * gc
return rc
}
struct QP: Differentiable {
var qt: QT
var pw: Float
}
@differentiable(reverse)
func CLP(f: ST, t: TT, qt: QT) -> QP
{
let rb = CRT(fr: f, tu: t, qt: qt)
let ct: Float = 1/rb
let d = f.tp - qt.tp
let pw = d * ct
var u = qt
u.pw = pw
let l = -pw
return QP(qt: u, pw: l)
}
@differentiable(reverse)
func UQ(q: QT) -> QT
{
let w = (q.fw * dt)
let m = (w * q.dt)
let e = q.pw * dt
let t = e / q.Cp / m
var u = q
u.tp = q.tp + t
u.pw = 0
return u
}
@differentiable(reverse)
func UB(pw: Float, f: ST) -> ST
{
var u = f
let v = f.ar * f.tk
let m = v * f.dt
u.tp = f.tp + ((pw * dt) / f.Cp / m)
return u
}
struct TQ: Differentiable {
var t: TnT
var q: QT
}
@differentiable(reverse)
func UST(s: TnT, q: QT) -> TQ
{
var u = s
var uq = q
let m = q.fw * q.dt
let dt = s.tp - q.tp
let p = dt * m * q.Cp
uq.pw = p
let tm = s.vl * s.dt
let r = (p * dt) / s.Cp / tm
u.tp = s.tp + r
return TQ(t: u, q: uq)
}
//-----------------------------------------------------------------------
@differentiable(reverse)
@inlinable func absDifferentiable(_ value: Float) -> Float {
if value < 0 {
return -value
}
return value
}
func LC(p: Float, gt: Float) -> Float {
let diff = p - gt
return absDifferentiable(diff)
}
@differentiable(reverse)
func SM(s: SP) -> Float {
let pt = s.t
var sb = s.s
var tn = s.tn
var q = s.q
sb.tp = s.stt //.scalarized
let taq = UST(s: tn, q: q)
tn = taq.t
q = taq.q
q = UQ(q: q)
let p = CLP(f: sb, t: pt, qt: q)
q = p.qt
let b = p.pw
q = UQ(q: q)
sb = UB(pw: b, f: sb)
return sb.tp
}
@differentiable(reverse)
func FP(s: SP) -> Float {
let p = SM(s: s)
let r = LC(p: p, gt: 27.344767)
return r
}
@inline(never)
@_silgen_name("foo")
func foo() -> SP.TangentVector {
let s = SP(stt: 33.3)
let (_, pb) = valueWithPullback(at: s, of: FP)
let grad = pb(1)
return grad
}
As we discussed, I modified the current closure-spec optimization to handle "returned" closures, and the performance of the benchmark has improved: the reverse-to-forward ratio has been cut in half (a little more, actually). And I think that with better placement of the optimization in the pipeline the performance can be improved further.
@jkshtj do you have a PR to look at?
We want to investigate possible optimizations in one of our key internal benchmarks for Swift AD, through a combination of inlining and closure-specialization. We are specifically pointing out those 2 optimizations because they can help us get rid of the memory allocations made by Swift AD for creating pullback closures.
The following code is representative of the structure of the benchmark. Other than the code shown in the functions, the sizes of the functions can be assumed to be the same, i.e., think of `// ...` as representing a constant number of inline differentiable operations. No control flow is involved.

Using the existing compiler optimizations, the top-level function `foo` ends up looking something like:

What we instead want is one of the following outcomes:

1. The VJP and pullback of `SM` are fully inlined into `foo`. This way the `partial_apply`s of intermediate pullbacks in `SM_VJP` should get constant-folded into the corresponding `apply`s of the pullbacks in `SM_PB`.
2. `SM_VJP` is fully inlined into `foo`. This way, even if we cannot inline `SM_PB` into `foo`, we can specialize it to take the values closed over by the intermediate pullback closures instead of the intermediate pullback closures themselves.

What's clear?
What's unclear?