microsoft / knossos-ksc

Compiler with automatic differentiation

vmap interface to elementwise ops #970

Closed (awf closed this 2 years ago)

awf commented 2 years ago

Replace "elementwise_apply_hack" with a "vmap=True" annotation on knossos.register.

Mostly code deletion thanks to #962 :)

Before:

@knossos.register
def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + erf(x / sqrt(2)))

@knossos.register
def vgelu(x: torch.Tensor):
    return elementwise_apply_hack("gelu", x)

After:

@knossos.vmap
def vgelu(x: float) -> float:
    return 0.5 * x * (1.0 + erf(x / sqrt(2)))

or

@knossos.register
def gelu(x: float) -> float:
    return 0.5 * x * (1.0 + erf(x / sqrt(2)))

vgelu = knossos.vmap(gelu)
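
For readers new to the term, the elementwise behaviour that knossos.vmap is meant to provide can be sketched in plain Python. This is only an illustration on nested lists, not how Knossos implements it (Knossos compiles the function); `vmap_sketch` is a hypothetical name used here for illustration.

```python
import math


def gelu(x: float) -> float:
    # Same scalar definition as in the PR description.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2)))


def vmap_sketch(f):
    """Hypothetical stand-in for knossos.vmap: lift a scalar function
    so that it applies to every element of a (possibly nested) list."""

    def mapped(xs):
        if isinstance(xs, (list, tuple)):
            # Recurse into each sub-list, preserving the nesting shape.
            return [mapped(x) for x in xs]
        return f(xs)

    return mapped


vgelu = vmap_sketch(gelu)
```

With this sketch, `vgelu([[0.0], [2.0]])` returns a nested list of the same shape, with `gelu` applied to each element.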

awf commented 2 years ago

@cgravill can you take a look at the dotnet install error on this Ubuntu build?

cgravill commented 2 years ago

> @cgravill can you take a look at the dotnet install error on this Ubuntu build?

Certainly, will investigate

cgravill commented 2 years ago

It's fixed in #977. You would need to cherry-pick that fix or merge from master into this branch, as the failure came from an external change that delisted an old minor version. Let me know if you'd like me to do it.

dcrc2 commented 2 years ago

Currently the test test_ts2k_vrelux succeeds if run on its own, but fails if run after test_ts2k_relux or test_ts2k_relux_grad.

The issue is that there is a clash in the names that we give the PyTorch extensions for this code:

@knossos.register
def relux(x: float):
    ...

vrelux = knossos.vmap(relux)

test_ts2k_relux compiles relux and test_ts2k_vrelux compiles vrelux, but both are given the extension name KscStub_test_torch_frontend_relux. When both tests are run, PyTorch detects that the name is already in use, and renames the second extension to KscStub_test_torch_frontend_relux_v1; however the name KscStub_test_torch_frontend_relux is hard-coded in our source file, so KscStub_test_torch_frontend_relux_v1 is not found.
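
The failure mode can be modelled with a toy sketch. It does not call PyTorch; `build` and `can_import` are hypothetical helpers that only mimic the behaviour described above, where the loader renames a clashing extension but the name baked into the source stays the same.

```python
HARD_CODED_NAME = "KscStub_test_torch_frontend_relux"


def build(requested_name, taken, name_in_source):
    """Toy model of building an extension.

    If requested_name is already in use, append _v1, _v2, ... as the
    loader is described as doing. Returns (name the loader will look for,
    name hard-coded in the generated source).
    """
    actual = requested_name
    version = 0
    while actual in taken:
        version += 1
        actual = f"{requested_name}_v{version}"
    taken.add(actual)
    return actual, name_in_source


def can_import(module):
    # The import only works if the loader's name matches the name in the source.
    looked_for, exported = module
    return looked_for == exported


taken = set()
relux_ext = build(HARD_CODED_NAME, taken, HARD_CODED_NAME)   # first test
vrelux_ext = build(HARD_CODED_NAME, taken, HARD_CODED_NAME)  # second test, same name
```

In this model `relux_ext` imports cleanly, while `vrelux_ext` is looked up under the `_v1` name and does not match the hard-coded one, which mirrors the order-dependent test failure.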

There are two possible fixes for this:

  1. Change the way the extension name is generated, so that the elementwise version is different.
  2. Don't hard-code the extension name in the C++ file; instead use the TORCH_EXTENSION_NAME macro as recommended at https://pytorch.org/tutorials/advanced/cpp_extension.html#binding-to-python

I'm thinking that we should probably implement both fixes.
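
As a rough sketch of fix 1, the name generator could take the elementwise flag into account. The function name `extension_name`, its parameters, and the `_vmap` suffix are assumptions for illustration only; the actual naming scheme used in the fix may differ.

```python
def extension_name(module_name: str, function_name: str, elementwise: bool) -> str:
    """Hypothetical name generator: give the elementwise (vmap) build
    a different extension name from the scalar build of the same function."""
    name = f"KscStub_{module_name}_{function_name}"
    if elementwise:
        name += "_vmap"  # assumed suffix, chosen only for this sketch
    return name


scalar_name = extension_name("test_torch_frontend", "relux", elementwise=False)
vmap_name = extension_name("test_torch_frontend", "relux", elementwise=True)
```

Because the two names differ, the second build no longer collides with the first, so no renaming is triggered. Fix 2 is complementary: with the name taken from TORCH_EXTENSION_NAME rather than hard-coded, a renamed extension would still be found.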

dcrc2 commented 2 years ago

> I'm thinking that we should probably implement both fixes.

I've pushed a fix using method 1.

cgravill commented 2 years ago

I also agree on doing both fixes - the logic may change on extension names.

dcrc2 commented 2 years ago

I've rebased this on master, squashing the original set of commits.

dcrc2 commented 2 years ago

> I also agree on doing both fixes - the logic may change on extension names.

I've made the second fix into a separate PR (#981)