microsoft / knossos-ksc

Compiler with automatic differentiation

Knossos.register, and delayed compilation #960

Closed: awf closed this 2 years ago

awf commented 2 years ago

Before:

The user must find a point at which to call Knossos compilation explicitly, and must then switch from the original function to the compiled version.

This was fine for experimentation, but this PR moves toward a smoother user experience and is a prelude to vmap.

After:

Functions are decorated with @knossos.register, and compilation can be explicit, or delayed to the first call.

import torch
import knossos

@knossos.register
def f(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sin(x)

This endows f with the following behaviours:

y = f(x)                            # Fast (C++/CUDA/...) computation of f(x)
torch.autograd.grad(y, x, dy)       # Fast computation of dot(dy, [df_i/dx_j])
...
f.entry_vjp(x, dy)                  # Stateless computation of the vector-Jacobian product
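For an elementwise function such as f(x) = x * sin(x), the Jacobian is diagonal, so the vector-Jacobian product dot(dy, [df_i/dx_j]) reduces to dy_i * (sin(x_i) + x_i * cos(x_i)). The pure-Python sketch below checks that closed form against a product formed from an explicit, finite-difference Jacobian. The helpers `f_ref`, `vjp_analytic` and `vjp_numeric` are hypothetical and not part of Knossos; they only illustrate the quantity that grad and f.entry_vjp return, not how Knossos computes it.

```python
import math
from typing import List


def f_ref(x: List[float]) -> List[float]:
    """Hypothetical reference f(x) = x * sin(x), applied elementwise."""
    return [xi * math.sin(xi) for xi in x]


def vjp_analytic(x: List[float], dy: List[float]) -> List[float]:
    """dot(dy, J) with J[i][j] = df_i/dx_j; J is diagonal because f is elementwise."""
    return [dyi * (math.sin(xi) + xi * math.cos(xi)) for xi, dyi in zip(x, dy)]


def vjp_numeric(x: List[float], dy: List[float], eps: float = 1e-6) -> List[float]:
    """Build each Jacobian column by central differences, then contract it with dy."""
    n = len(x)
    result = [0.0] * n
    for j in range(n):
        xp = list(x)
        xm = list(x)
        xp[j] += eps
        xm[j] -= eps
        # Column j of the Jacobian: df_i/dx_j for every output i.
        col = [(a - b) / (2 * eps) for a, b in zip(f_ref(xp), f_ref(xm))]
        # Entry j of the vector-Jacobian product: sum_i dy_i * J[i][j].
        result[j] = sum(dyi * cij for dyi, cij in zip(dy, col))
    return result


x = [0.5, 1.0, 2.0]
dy = [1.0, -2.0, 0.25]
print(vjp_analytic(x, dy))
print(vjp_numeric(x, dy))
```

Because the result is a single vector of the same shape as x, a stateless entry point such as f.entry_vjp(x, dy) can return it directly, without keeping an autograd tape around.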

Compilation is delayed until the first call to f, or until f.compile() is called explicitly.
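One way to get this "compile on first call, or on demand" behaviour is a wrapper object that caches the compiled function. The sketch below assumes a hypothetical `compile_stub` in place of the real Knossos compiler (it just returns the Python function) and a hypothetical `Registered` class. It is not the Knossos implementation, only an illustration of the delayed-compilation pattern.

```python
import functools
from typing import Any, Callable, Optional


def compile_stub(fn: Callable) -> Callable:
    """Hypothetical stand-in for the Knossos compiler: returns fn unchanged."""
    return fn


class Registered:
    """Wraps a Python function and compiles it lazily (a sketch, not Knossos's class)."""

    def __init__(self, fn: Callable,
                 compiler: Callable[[Callable], Callable] = compile_stub) -> None:
        functools.update_wrapper(self, fn)  # copy __name__, __doc__, etc.
        self._fn = fn
        self._compiler = compiler
        self._compiled: Optional[Callable] = None
        self.compile_count = 0  # how many times the compiler actually ran

    def compile(self) -> None:
        """Explicit compilation; does nothing if already compiled."""
        if self._compiled is None:
            self._compiled = self._compiler(self._fn)
            self.compile_count += 1

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        self.compile()  # delayed: only the first call pays for compilation
        return self._compiled(*args, **kwargs)


def register(fn: Callable) -> Registered:
    """Hypothetical decorator returning a lazily compiled wrapper."""
    return Registered(fn)


@register
def square(x: float) -> float:
    return x * x


print(square.compile_count)  # 0: nothing is compiled at decoration time
print(square(3.0))           # 9.0: first call triggers compilation
print(square(4.0))           # 16.0: reuses the cached compiled function
print(square.compile_count)  # 1
```

Keeping the compiled artefact on the wrapper means callers never switch between an "original" and a "compiled" name, which is the usability problem described above.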

awf commented 2 years ago

e.g. does this mean that you have to fix the value of generate_lm? If you want to compare the behaviour of generate_lm=False with generate_lm=True, do you have to write a second copy of the function?

A decorator can also be called directly on a function, so you can write:

def fun(..):
  ...
fun_lm = knossos.register(fun, generate_lm=True)
fun_nolm = knossos.register(fun, generate_lm=False)
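A decorator that accepts options can be written so that it works both as `@register`, as `@register(generate_lm=True)`, and as a plain call `register(fun, generate_lm=True)`. The sketch below is hypothetical: its `register` only records `generate_lm` on the wrapper, because the real meaning of that flag in Knossos is not described here. The point it shows is that one undecorated definition can yield several independently configured variants.

```python
import functools
from typing import Any, Callable, Optional


def register(fn: Optional[Callable] = None, *, generate_lm: bool = False):
    """Hypothetical register usable as @register, @register(...), or register(fn, ...)."""

    def wrap(f: Callable) -> Callable:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        wrapper.generate_lm = generate_lm  # record the option on this variant only
        return wrapper

    if fn is None:
        return wrap  # called as @register(generate_lm=...)
    return wrap(fn)  # called as @register or register(fn, generate_lm=...)


def fun(x: float) -> float:
    return 2.0 * x


fun_lm = register(fun, generate_lm=True)
fun_nolm = register(fun, generate_lm=False)
print(fun_lm(3.0), fun_lm.generate_lm)      # 6.0 True
print(fun_nolm(3.0), fun_nolm.generate_lm)  # 6.0 False
```

The original `fun` is left untouched, so comparing the two settings needs no second copy of the function body.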