erizmr opened this issue 2 years ago
I wonder if explicitly differentiating a function, as in JAX, will be supported. For example:

```python
@ti.func
def f(x):
    return x**3 + 2*x**2 - 3*x + 1

dfdx = forward(f)

@ti.kernel
def k() -> float:
    return dfdx(1.0)

k()  # returns 4.0
```
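For reference, the corresponding calls in JAX (assuming `jax` is installed) would be roughly:

```python
import jax

def f(x):
    return x**3 + 2*x**2 - 3*x + 1

# jax.jacfwd builds the forward-mode derivative of f;
# jax.jvp evaluates a Jacobian-vector product at a point with a given tangent.
dfdx = jax.jacfwd(f)
print(dfdx(1.0))                   # 4.0
print(jax.jvp(f, (1.0,), (1.0,)))  # (1.0, 4.0): primal value and derivative
```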
I think it is possible to support similar features. A naive equivalent in current Taichi is:

```python
import taichi as ti

ti.init()

x = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x[None]**3 + 2*x[None]**2 - 3*x[None] + 1

def dfdx(_x):
    x[None] = _x
    y.grad[None] = 1.0
    f.grad()
    return x.grad[None]

print(dfdx(1.0))  # 4.0
```
For a more general case, it may be necessary to specify the `input` and `output` fields if we would like to generate `dfdx` for the users. A possible implementation might be:
```python
import taichi as ti
import numpy as np

ti.init()

x1 = ti.field(float, shape=(), needs_grad=True)
x2 = ti.field(float, shape=(), needs_grad=True)
x3 = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x1[None]**3 + 2*x2[None]**2 - 3*x3[None] + 1

def backward(f, input_fields, out_field):
    out_field.grad[None] = 1.0
    def _dfdx(inputs):
        # Load the input values, run the reverse-mode kernel once,
        # and collect the gradient of every input field.
        for i, value in enumerate(inputs):
            input_fields[i].from_numpy(np.array(value))
        f.grad()
        return [field.grad.to_numpy() for field in input_fields]
    return _dfdx

dfdx = backward(f, [x1, x2, x3], y)
print(dfdx([1.0, 2.0, 3.0]))  # [3, 8, -3]
```
In this issue, we would like to share a draft implementation plan for forward mode autodiff.

## Background
In general, there are two modes of autodiff: reverse mode and forward mode, and each has its advantages in different scenarios. Reverse mode is more efficient when the number of inputs is much larger than the number of outputs (e.g., machine learning cases with thousands of trainable parameters and one scalar loss). Conversely, forward mode is more efficient when the number of outputs is much larger than the number of inputs. In addition, second-order derivatives can be computed efficiently by combining the forward and reverse modes.
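For instance (using JAX here only because its API is already public; the corresponding Taichi interface is what this plan is about), a second derivative can be obtained by applying forward mode over reverse mode:

```python
import jax

def f(x):
    return x**3 + 2*x**2 - 3*x + 1

# Reverse mode for the first derivative, forward mode over it for the second.
d2fdx2 = jax.jacfwd(jax.grad(f))
print(d2fdx2(1.0))  # f''(x) = 6*x + 4 -> 10.0 at x = 1.0
```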
For the overall roadmap of the autodiff feature in Taichi, please check out #5050.
## Goals

Support forward mode autodiff that can be composed with reverse mode (e.g., `forward(reverse())`), preparing for computing second-order derivatives.

## Implementation Roadmap
- Implement forward mode autodiff
  - `dual` snodes for fields #5083
  - stop gradients for `dual` snodes
  - lazy gradient for `dual` snodes (allocate `dual` by the forward mode context manager) #5146 #5224
- Design the Python interface for forward and reverse mode (see the sketch after this list)
  - [x] Decouple `adjoint` and `grad`, making `grad` include both `adjoint` and `dual` #5083
  - `grad` indicates `adjoint`; expose `dual` for forward mode #5224
- Python test cases
- Second-order derivative
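To make the "Design the Python interface" item above concrete, here is a rough usage sketch. The names `needs_dual`, `ti.ad.FwdMode`, and the `.dual` accessor follow the roadmap vocabulary but are assumptions for illustration, not a finalized API:

```python
# Hypothetical forward-mode interface sketch; `needs_dual`, `ti.ad.FwdMode`,
# and `.dual` are illustrative names, not a finalized API.
import taichi as ti

ti.init()

x = ti.field(float, shape=(), needs_dual=True)
y = ti.field(float, shape=(), needs_dual=True)

@ti.kernel
def f():
    y[None] = x[None]**3 + 2*x[None]**2 - 3*x[None] + 1

x[None] = 1.0
# The context manager would allocate the dual snodes lazily and seed the
# tangent of `x`, so a single launch of f() yields both y and dy/dx.
with ti.ad.FwdMode(loss=y, param=x):
    f()

print(y.dual[None])  # expected 4.0
```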
## Discussions
Currently, in reverse mode, two kernels (the `original kernel` and the `grad kernel`) are compiled to evaluate the function values and to compute the gradients, respectively. However, in forward mode autodiff the derivatives are computed eagerly during function evaluation, i.e., the function values and gradients can be computed with only one kernel. This raises the question of whether one or two kernels need to be compiled.

Update: three kinds of kernels are generated, `primal`, `forward ad`, and `reverse ad`, according to the autodiff mode; see #5098.
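As a side note on why one kernel suffices in forward mode: a minimal dual-number sketch in plain Python (the `Dual` class is purely illustrative) shows how a single pass propagates the value and the derivative together:

```python
# Forward mode propagates (value, derivative) pairs through the computation
# in a single pass, so no separate gradient kernel is needed.
class Dual:
    def __init__(self, val, dot):
        self.val = val   # primal value
        self.dot = dot   # derivative w.r.t. the seeded input

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __sub__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other, 0.0)
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)

    __rmul__ = __mul__

    def __pow__(self, n):  # integer powers only, enough for this example
        return Dual(self.val**n, n * self.val**(n - 1) * self.dot)

x = Dual(1.0, 1.0)              # seed dx/dx = 1
y = x**3 + 2*x**2 - 3*x + 1
print(y.val, y.dot)             # 1.0 4.0
```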