microsoft / knossos-ksc

Compiler with automatic differentiation
Other
45 stars 10 forks source link

Async Knossos compiler #835

Status: Open · awf opened this issue 3 years ago

awf commented 3 years ago

[Writing in progress]

class KnossosVMapCompiler:
    """Vectorised wrapper around ``f`` that switches to a Knossos-compiled
    implementation once a background compilation has finished.

    Until a compiled module matching the current argument's characteristics
    is available, calls are served by ``torch.vmap(f)``. When a call arrives
    whose characteristics (as reported by
    ``arg_characteristics_requiring_recompilation``) differ from both the
    already-compiled target and the in-flight target, a new compilation is
    started in the background and the fallback is used meanwhile.

    NOTE(review): design sketch (work in progress). The helpers
    ``rand_of_standard_shape``, ``arg_characteristics_requiring_recompilation``
    and ``knossos_vmap_do_it`` are defined elsewhere; the argument lists used
    for them below are assumptions -- confirm.
    """

    # Characteristics the current ``py_mod`` was compiled for; False means
    # "nothing compiled yet" (default kept from the original sketch).
    compiled = False
    # Future for the in-flight compilation, or None when idle.
    compilation_future = None

    def __init__(self, f):
        import concurrent.futures

        self.f = f
        self.torch_fallback = torch.vmap(f)
        # Compiled module for ``self.compiled`` characteristics, or None.
        self.py_mod = None
        self.compiled = False
        self.compilation_future = None
        # Characteristics targeted by ``self.compilation_future``.
        self._pending_key = None
        # Single worker: compilations run one at a time, off the calling thread.
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

        # Start compiling for a representative argument so a compiled module
        # may already be ready by the time real calls arrive.
        # NOTE(review): assumes ``f.arg.Type`` describes f's argument type,
        # as in the original sketch -- confirm.
        generic_example_arg = rand_of_standard_shape(f.arg.Type)
        self._start_compilation(
            generic_example_arg,
            arg_characteristics_requiring_recompilation(generic_example_arg),
        )

    def _start_compilation(self, example_arg, key):
        """Submit a background compilation targeting characteristics ``key``."""
        if self.compilation_future is not None:
            # Superseded by a newer request; cancel() only succeeds if the
            # job has not started running yet.
            self.compilation_future.cancel()
        self._pending_key = key
        # NOTE(review): the sketch wrote ``Future(knossos_vmap_do_it)`` with no
        # arguments; passing ``(f, example_arg)`` is an assumption -- confirm.
        self.compilation_future = self._executor.submit(
            knossos_vmap_do_it, self.f, example_arg
        )

    def _collect_finished_compilation(self):
        """Install the result of the in-flight compilation if it has finished."""
        future = self.compilation_future
        if future is None or not future.done():
            return
        import warnings

        self.compilation_future = None
        target, self._pending_key = self._pending_key, None
        try:
            self.py_mod = future.result()
        except Exception as err:
            # A failed background compile must not break callers: warn and
            # keep serving through the torch fallback.
            warnings.warn(
                f"Knossos vmap compilation failed; using torch.vmap fallback: {err!r}"
            )
            self.py_mod = None
        # Record the target even on failure so the same characteristics are
        # not recompiled on every subsequent call.
        self.compiled = target

    def __call__(self, arg):
        """Apply the vectorised ``f`` to ``arg``, preferring the compiled module."""
        key = arg_characteristics_requiring_recompilation(arg)
        self._collect_finished_compilation()

        in_flight = self.compilation_future is not None and key == self._pending_key
        if key != self.compiled and not in_flight:
            self._start_compilation(arg, key)

        # Only use the compiled module when it was built for these
        # characteristics; otherwise it may be invalid for ``arg``.
        if self.py_mod is not None and key == self.compiled:
            return self.py_mod(arg)
        return self.torch_fallback(arg)

def knossos_vmap(function, generate_lm=True):
    """Build a vmapped callable for ``function`` backed by ``KnossosVMapCompiler``.

    ``generate_lm`` is accepted but not currently used.
    NOTE(review): presumably intended to be forwarded to the compiler -- confirm.
    """
    compiler = KnossosVMapCompiler(function)
    return compiler