microsoft / knossos-ksc

Compiler with automatic differentiation

Couldn't find [aten::div (Tuple (Tensor 1 Float) Float)] applied to ((Tensor 1 Float),Float) #912

Closed ryotatomioka closed 3 years ago

ryotatomioka commented 3 years ago

I tried the following:

import sys
from ksc import torch_frontend
sys.path.append("examples/dl-activations")
import gelu
configs = list(gelu.vgelu_bench_configs())
example_inputs = configs[0],
ksdefs = torch_frontend.tsmod2kscdefs(gelu, "vgelu_pytorch", example_inputs)
torch_frontend.ksc_defs_to_module(ksdefs, ksdefs[0], ["sufrev"])

(where tsmod2kscdefs just returns ksc_defs instead of calling ksc_defs_to_autograd_function)

The error seems to suggest that we don't have a broadcasted division of a 1D tensor by a scalar. Shall I just add that to prelude-aten.ks? How do we generally handle broadcasted torch operations?

awf commented 3 years ago

Yes please add to prelude-aten.ks

We explicitly write all combinations of broadcasting because the types need to match. That is, we need

(edef aten::div (Tensor 1 Float) ((a : Float) (b : Tensor 1 Float)))
(edef aten::div (Tensor 2 Float) ((a : Float) (b : Tensor 2 Float)))
(edef aten::div (Tensor 1 Float) ((a : Tensor 1 Float) (b : Float)))
(edef aten::div (Tensor 2 Float) ((a : Tensor 2 Float) (b : Float)))
(edef aten::div (Tensor 1 Float) ((a : Tensor 1 Float) (b : Tensor 1 Float)))
(edef aten::div (Tensor 2 Float) ((a : Tensor 1 Float) (b : Tensor 2 Float)))

and so on, for all tensor ranks m and n. We could autogenerate these if we liked, but for now it's manual.
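As a rough illustration of what autogenerating these declarations could look like, here is a minimal Python sketch. It is not part of ksc: the helper names (`arg_type`, `return_type`, `broadcast_edefs`) are hypothetical, the output simply mimics the edef layout used in this thread, and it assumes the result rank is the larger of the two operand ranks.

```python
from itertools import product


def arg_type(rank):
    """Type as written in an argument position: rank 0 means a scalar Float."""
    return "Float" if rank == 0 else f"Tensor {rank} Float"


def return_type(rank):
    """Type as written in the return position (tensor types are parenthesised)."""
    return "Float" if rank == 0 else f"({arg_type(rank)})"


def broadcast_edefs(name, max_rank):
    """Return one edef line per (rank_a, rank_b) pair with ranks in 0..max_rank.

    Assumption: the broadcast result has rank max(rank_a, rank_b).
    The scalar/scalar pair is skipped, assuming it is handled elsewhere.
    """
    lines = []
    for rank_a, rank_b in product(range(max_rank + 1), repeat=2):
        if rank_a == 0 and rank_b == 0:
            continue
        result = return_type(max(rank_a, rank_b))
        lines.append(
            f"(edef {name} {result} "
            f"((a : {arg_type(rank_a)}) (b : {arg_type(rank_b)})))"
        )
    return lines


if __name__ == "__main__":
    print("\n".join(broadcast_edefs("aten::div", 2)))
```

Running it with `max_rank=2` emits eight lines, one per rank combination, which could be pasted into a prelude file after review.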

This may seem horribly inefficient, but it is exactly what the C++ compiler ends up doing anyway, so we're just making it clear.

We can also use an elementwise helper in C++ to save code space.
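The helper idea can be sketched as follows. This is written in Python rather than C++ purely for brevity, and every name in it is hypothetical: one shared `elementwise` routine does the work, and each explicitly typed entry point is a one-line wrapper around it. The sketch only covers scalar-with-tensor and equal-shape tensor pairs, with tensors modeled as nested lists; it does not attempt rank-mismatched tensor broadcasting.

```python
import operator


def elementwise(op, a, b):
    """Apply op elementwise; a bare float is treated as a scalar to broadcast.

    Only scalar/tensor and equal-shape tensor/tensor pairs are supported.
    """
    a_is_tensor = isinstance(a, list)
    b_is_tensor = isinstance(b, list)
    if not a_is_tensor and not b_is_tensor:
        return op(a, b)
    if a_is_tensor and b_is_tensor:
        if len(a) != len(b):
            raise ValueError("shape mismatch")
        return [elementwise(op, x, y) for x, y in zip(a, b)]
    if a_is_tensor:
        return [elementwise(op, x, b) for x in a]
    return [elementwise(op, a, y) for y in b]


# Each typed combination stays a separate, explicitly named function,
# but the bodies share one implementation.
def div_tensor1_float(a, b):
    return elementwise(operator.truediv, a, b)


def div_float_tensor1(a, b):
    return elementwise(operator.truediv, a, b)


def div_tensor2_float(a, b):
    return elementwise(operator.truediv, a, b)
```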

awf commented 3 years ago

Does the error print out the code it wants you to add? I.e. I hope you hit https://github.com/microsoft/knossos-ksc/blob/0b3a4f2aa035b3166e8f76a9dee9b4dba810c4e5/src/python/ksc/type_propagate.py#L302 and not somewhere else.

ryotatomioka commented 3 years ago

Yes, sorry, I missed that message because it was printed above the lengthy tracebacks.

ryotatomioka commented 3 years ago

I think this could be correct?

(def aten::div (Tensor 1 Float) ((a : Tensor 1 Float) (b : Float))
    (build (size a) (lam (inds : Integer )
        (div (index inds a) b))))
(gdef fwd [aten::div (Tuple (Tensor 1 Float) Float)])
(gdef rev [aten::div (Tuple (Tensor 1 Float) Float)])
(gdef suffwdpass [aten::div (Tuple (Tensor 1 Float) Float)])
(gdef sufrevpass [aten::div (Tuple (Tensor 1 Float) Float)])
(gdef sufrev [aten::div (Tuple (Tensor 1 Float) Float)])
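One way to sanity-check the `def` above is to mirror it in plain Python. The sketch below models `build`, `index` and `size` as small hypothetical Python functions over lists, assuming `build n f` produces a length-n tensor whose i-th element is `f(i)`; it says nothing about the `gdef` lines, which are not modeled here.

```python
def build(n, f):
    """Model of `build`: a length-n list whose i-th entry is f(i)."""
    return [f(i) for i in range(n)]


def index(i, t):
    """Model of `index`: the i-th element of a 1D tensor."""
    return t[i]


def size(t):
    """Model of `size` for a 1D tensor: its length."""
    return len(t)


def aten_div_tensor1_float(a, b):
    """Mirror of the proposed def: divide each element of a by the scalar b."""
    return build(size(a), lambda inds: index(inds, a) / b)


if __name__ == "__main__":
    print(aten_div_tensor1_float([2.0, 4.0, 6.0], 2.0))
```

The result can then be compared against `torch.div` on the same inputs to confirm the intended broadcasting behaviour.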