taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0

[RFC] Add forward mode for autodiff #5055

Open erizmr opened 2 years ago

erizmr commented 2 years ago

In this issue, we would like to share a draft implementation plan for forward-mode autodiff.

Background

In general, there are two modes for autodiff: reverse mode and forward mode. Each mode has its advantage in different scenarios. Reverse mode is more efficient when the number of inputs is much larger than the number of outputs (e.g., machine learning cases, with thousands of trainable parameters and one scalar loss). Conversely, forward mode is more efficient when the number of outputs is much larger than the number of inputs. In addition, second-order derivatives can be computed efficiently by combining the forward and reverse modes.
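To make the trade-off concrete, here is a minimal dual-number sketch in plain Python (purely illustrative, not Taichi's implementation): forward mode carries a derivative ("dot") alongside every value, so a single pass yields the derivative with respect to one seeded input. With n inputs, recovering the full gradient this way takes n passes, whereas one reverse pass yields the derivative of one output with respect to all inputs at once.

class Dual:
    """A value paired with its derivative w.r.t. one chosen input."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def _wrap(self, other):
        return other if isinstance(other, Dual) else Dual(float(other))

    def __add__(self, other):
        other = self._wrap(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = self._wrap(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val,
                    self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__

def f(x):
    return 3 * x * x + 2 * x  # f'(x) = 6x + 2

y = f(Dual(2.0, dot=1.0))  # seed dx/dx = 1
print(y.val, y.dot)  # 16.0 14.0: value and derivative in one pass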

For a roadmap for the autodiff feature in Taichi, please check out #5050.

Goals

Implementation Roadmap

Discussions

Currently, in reverse mode, two kernels are compiled: the original kernel, which evaluates the function values, and the grad kernel, which computes the gradients. In forward mode, however, the derivatives are computed eagerly during function evaluation, i.e., the function values and the derivatives can be computed with a single kernel. This raises the question of whether one kernel or two should be compiled.

Update: three kinds of kernels are now generated, depending on the autodiff mode: primal, forward AD, and reverse AD. See #5098.
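For context, the forward-mode user interface that later shipped in Taichi exposes "dual" fields and a ti.ad.FwdMode context manager; a minimal sketch, assuming that post-RFC API (Taichi >= 1.1):

import taichi as ti

ti.init()

# Forward mode tracks "dual" values instead of "grad" values.
x = ti.field(float, shape=(), needs_dual=True)
y = ti.field(float, shape=(), needs_dual=True)

@ti.kernel
def f():
    y[None] = x[None]**3 + 2 * x[None]**2 - 3 * x[None] + 1

x[None] = 1.0
# A single forward pass computes the value and the derivative together.
with ti.ad.FwdMode(loss=y, param=x):
    f()

print(y[None], y.dual[None])  # 1.0 4.0: f(1) and df/dx at x = 1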

victoriacity commented 2 years ago

I wonder if explicitly differentiating a function, as in JAX, will be supported. For example:

@ti.func
def f(x):
    return x**3 + 2*x**2 - 3*x + 1

dfdx = forward(f)  # `forward` is the proposed forward-mode transform

@ti.kernel
def k() -> float:
    return dfdx(1.0)

k()  # returns 4.0
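For reference, the JAX pattern being alluded to looks roughly like this (jax.jacfwd is JAX's forward-mode transform; jax.grad is the reverse-mode one):

import jax

def f(x):
    return x**3 + 2 * x**2 - 3 * x + 1

dfdx = jax.jacfwd(f)  # forward mode; jax.grad(f) would use reverse mode
print(dfdx(1.0))      # 4.0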
erizmr commented 2 years ago

I think it is possible to support similar features. A naive equivalent in current Taichi is:

import taichi as ti

ti.init()

x = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x[None]**3 + 2*x[None]**2 - 3*x[None] + 1

def dfdx(_x):
    x[None] = _x
    # Seed the output gradient, then run the reverse-mode kernel.
    y.grad[None] = 1.0
    f.grad()
    return x.grad[None]

print(dfdx(1.0))  # 4.0
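One caveat with this sketch: Taichi's grad kernels accumulate into the .grad fields, so calling dfdx more than once would keep adding to x.grad. Repeated evaluations need x.grad[None] reset to zero before each f.grad() call.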

For the more general case, we may need to specify the inputs and the output if we want to generate dfdx for users. A possible implementation might be:

import taichi as ti

ti.init()

x1 = ti.field(float, shape=(), needs_grad=True)
x2 = ti.field(float, shape=(), needs_grad=True)
x3 = ti.field(float, shape=(), needs_grad=True)
y = ti.field(float, shape=(), needs_grad=True)

@ti.kernel
def f():
    y[None] += x1[None]**3 + 2*x2[None]**2 - 3*x3[None] + 1

def backward(f, input_fields, out_field):
    def _dfdx(inputs):
        # Load the input values into their fields.
        for field, value in zip(input_fields, inputs):
            field[None] = value
        # Seed the output gradient, then run the reverse-mode kernel.
        out_field.grad[None] = 1.0
        f.grad()
        # Collect the gradient with respect to each input.
        return [field.grad[None] for field in input_fields]
    return _dfdx

dfdx = backward(f, [x1, x2, x3], y)

print(dfdx([1.0, 2.0, 3.0]))  # [3, 8, -3]
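Note that a single reverse pass recovers the whole gradient here precisely because f maps three inputs to one scalar output. The forward mode proposed in this issue would instead need one pass per input (one seed each), which is the inputs-versus-outputs trade-off described in the background section.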