microsoft / knossos-ksc

Compiler with automatic differentiation

Implement vmap #1019

Closed (awf closed this 2 years ago)

awf commented 2 years ago

This PR implements vmap. Given a function

```python
@knossos.vmap
def foo(a: Tensor[N, M], b: float, c: Tensor[L]) -> Tensor[R]:
    ...
```

where the annotation Tensor[N,M] denotes a 2-dimensional tensor with dimensions N and M, vmap creates a function

```python
def foo(a: Tensor[B, N, M], b: Tensor[B], c: Tensor[B, L]) -> Tensor[B, R]:
    ...
```

mapping calls to foo over the "batch" dimension B. The gradient is mapped over B in the same way.
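To pin down what the batched function should compute, here is a loop-based sketch in plain Python. It is an illustration under stated assumptions, not the knossos API: `vmap_reference`, the body of `foo`, and `foo_vjp` are all hypothetical, tensors are modelled as nested lists, and the loop only describes the result of a vmapped function, not how Knossos generates code for it. Applying the same wrapper to a hand-written per-example gradient shows what "the gradient is mapped over B" means: each batch element gets its own gradient.

```python
from typing import Any, Callable, List, Sequence, Tuple


def vmap_reference(f: Callable[..., Any]) -> Callable[..., List[Any]]:
    """Loop-based reference semantics for vmap (hypothetical helper).

    Every argument of the returned function carries an extra leading batch
    dimension B. f is called once per batch element, and the B results are
    collected into a list, which plays the role of the output's leading B axis.
    """
    def batched(*args: Sequence[Any]) -> List[Any]:
        if not args:
            raise ValueError("vmap_reference needs at least one batched argument")
        sizes = {len(a) for a in args}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
        (batch,) = sizes
        # Call f on the i-th slice of every argument.
        return [f(*(a[i] for a in args)) for i in range(batch)]
    return batched


def foo(a: List[List[float]], b: float, c: List[float]) -> List[float]:
    # Hypothetical per-example function: a is N x M, b is a scalar,
    # c has length L, and the result has length R = N.
    shift = sum(c)
    return [b * sum(row) + shift for row in a]


def foo_vjp(a: List[List[float]], b: float, c: List[float],
            dret: List[float]) -> Tuple[List[List[float]], float, List[float]]:
    # Hand-derived reverse-mode gradient of foo for ONE example, given the
    # output cotangent dret (length N). Returns (da, db, dc).
    da = [[b * dret[n] for _ in row] for n, row in enumerate(a)]
    db = sum(dret[n] * sum(row) for n, row in enumerate(a))
    dc = [sum(dret) for _ in c]
    return da, db, dc


# Tensor[B,N,M], Tensor[B], Tensor[B,L] -> Tensor[B,R]
vfoo = vmap_reference(foo)
# The gradient is mapped over B in exactly the same way.
vfoo_vjp = vmap_reference(foo_vjp)
```

For example, with B = 2, `vfoo([[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [1.0, 0.0]]], [1.0, 2.0], [[1.0], [0.5]])` evaluates `foo` on each of the two examples and returns `[[4.0, 8.0], [2.5, 2.5]]`.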

TODOs:

Benchmarks (with generate_lm = True)

Size 10x4x4

Knossos beats all the other variants.

Size 1000x4x4

Knossos beats PyTorch nice by 100x. But for the backward pass, which is the one we care about most, PyTorch hand-vectorized is 4x faster than Knossos on CPU, and close to that on GPU vs CPU.

Size 1000x16x16

Knossos beats PyTorch nice by 15x, but is 18x slower than PyTorch hand-vectorized.
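If both ratios for the 1000x16x16 case refer to the same pass on the same hardware (the comment does not say), they can be chained to estimate how far apart the two PyTorch variants are, writing t for runtime:

$$\frac{t_{\text{nice}}}{t_{\text{hand}}} = \frac{t_{\text{nice}}}{t_{\text{KS}}}\cdot\frac{t_{\text{KS}}}{t_{\text{hand}}} \approx 15 \times 18 = 270$$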

awf commented 2 years ago

Todo: reset allocation every iter

awf commented 2 years ago

> Todo: reset allocation every iter

The PR now resets the KS allocator after every inplace_add. This makes no difference for vsqrl, because a KS function that returns a float always resets the allocator anyway, but it is the right thing to do.
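The thread does not show the allocator's interface, so the following is a hypothetical model: it assumes the KS allocator behaves like a bump (arena) allocator with mark/reset, and `BumpArena`, `inplace_add`, and `batched_sum` are illustrative names, not Knossos code. It shows why resetting after each iteration's inplace_add matters: without the reset, per-example temporaries pile up and peak memory grows with B; with it, peak memory stays at one iteration's worth.

```python
from typing import List


class BumpArena:
    """Hypothetical bump (arena) allocator standing in for the KS allocator.

    Allocation just advances an offset; mark() records the current offset and
    reset(m) frees everything allocated since that mark in one step.
    """

    def __init__(self) -> None:
        self.top = 0   # current offset (number of floats in use)
        self.peak = 0  # high-water mark, to observe memory growth

    def alloc(self, n: int) -> int:
        start = self.top
        self.top += n
        self.peak = max(self.peak, self.top)
        return start

    def mark(self) -> int:
        return self.top

    def reset(self, m: int) -> None:
        self.top = m


def inplace_add(acc: List[float], x: List[float]) -> None:
    # Hypothetical stand-in: accumulate x into acc elementwise.
    for i, v in enumerate(x):
        acc[i] += v


def batched_sum(arena: BumpArena, batch: List[List[float]],
                reset_each_iter: bool) -> List[float]:
    # Loop over the batch; each iteration allocates a per-example temporary.
    acc = [0.0] * len(batch[0])
    for example in batch:
        m = arena.mark()
        arena.alloc(len(example))           # space for the per-example result
        tmp = [2.0 * v for v in example]    # the per-example computation
        inplace_add(acc, tmp)               # result is now copied out of the arena
        if reset_each_iter:
            arena.reset(m)                  # safe: nothing from this iteration is still needed
    return acc
```

With a batch of 1000 examples of length 3, `reset_each_iter=True` keeps `arena.peak` at 3, while `reset_each_iter=False` lets it reach 3000. When the per-example function returns a float, its temporaries are presumably already released on return, which would explain why the extra reset changes nothing for vsqrl; it matters when the per-example result is itself a tensor that must survive until inplace_add has consumed it.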