Open awf opened 3 years ago
[Writing in progress]
# Sentinel meaning "no argument characteristics recorded"; distinct from any value
# arg_characteristics_requiring_recompilation might return (including None).
_NO_KEY = object()


class KnossosVMapCompiler:
    """Callable that vmaps ``f``, preferring a Knossos-compiled module when ready.

    Construction starts a background compilation for a generic example argument.
    Each call computes the argument's recompilation-relevant characteristics.
    If a compiled module exists for exactly those characteristics, it is used.
    Otherwise the call falls back to ``torch.vmap(f)``. If no compilation for
    those characteristics is already pending, a new one is scheduled.

    Helpers assumed to be provided elsewhere (signatures are assumptions --
    TODO confirm):
      * ``rand_of_standard_shape(type)`` -> an example argument;
      * ``arg_characteristics_requiring_recompilation(arg)`` -> a value
        comparable with ``==``;
      * ``knossos_vmap_do_it(f, example_arg)`` -> a callable module, or ``None``
        if no compiled module can be produced.
    """

    def __init__(self, f, *, generate_lm=True):
        """Wrap ``f`` and kick off compilation for a generic example argument.

        Args:
            f: Function to vmap. Presumably exposes ``f.arg.Type`` describing its
                argument type (TODO confirm).
            generate_lm: Stored for the compiler. It is not yet forwarded to
                ``knossos_vmap_do_it`` (TODO wire through once that signature
                is settled).
        """
        import concurrent.futures

        self.f = f
        self.generate_lm = generate_lm
        self.torch_fallback = torch.vmap(f)
        # A single worker serialises compilations, so a superseded request that
        # has not started yet can still be cancelled.
        self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Characteristics that the current ``py_mod`` was compiled for.
        self.compiled = _NO_KEY
        self.py_mod = None
        # Characteristics targeted by ``compilation_future``, if one is pending.
        self._pending_key = _NO_KEY
        self.compilation_future = None

        generic_example_arg = rand_of_standard_shape(f.arg.Type)
        # Tag the eager compilation with its own characteristics. Then a first
        # call with a matching argument waits for it instead of discarding it.
        self._start_compilation(
            generic_example_arg,
            arg_characteristics_requiring_recompilation(generic_example_arg),
        )

    def _start_compilation(self, example_arg, key):
        """Submit a background compilation for ``example_arg``, tagged with ``key``."""
        if self.compilation_future is not None:
            # Superseded request: cancel() only succeeds if it has not started.
            self.compilation_future.cancel()
        self._pending_key = key
        self.compilation_future = self._executor.submit(
            knossos_vmap_do_it, self.f, example_arg
        )

    def _harvest_finished_compilation(self):
        """Adopt the result of the pending compilation if it has finished."""
        future = self.compilation_future
        if future is None or not future.done():
            return
        error = future.exception()
        if error is None:
            self.py_mod = future.result()
        else:
            import warnings

            warnings.warn(
                f"Knossos compilation failed; using torch.vmap fallback: {error!r}",
                RuntimeWarning,
                stacklevel=3,
            )
            # Keep falling back for these characteristics rather than retrying
            # the failed compilation on every call.
            self.py_mod = None
        # Record the characteristics this result was built for, not those of
        # whatever argument happens to be in hand now.
        self.compiled = self._pending_key
        self._pending_key = _NO_KEY
        self.compilation_future = None

    def __call__(self, arg):
        """Apply the vmapped function to ``arg``.

        Uses the compiled module if one is ready for ``arg``'s characteristics.
        Otherwise uses ``torch.vmap(f)`` and, if no matching compilation is
        already pending, schedules one for ``arg``.
        """
        self._harvest_finished_compilation()
        key = arg_characteristics_requiring_recompilation(arg)
        if key == self.compiled and self.py_mod is not None:
            return self.py_mod(arg)
        if key != self.compiled and key != self._pending_key:
            self._start_compilation(arg, key)
        return self.torch_fallback(arg)


def knossos_vmap(function, generate_lm=True):
    """Return a callable that vmaps ``function`` via Knossos, falling back to ``torch.vmap``.

    Args:
        function: The function to vmap.
        generate_lm: Forwarded to ``KnossosVMapCompiler`` (previously dropped).
    """
    return KnossosVMapCompiler(function, generate_lm=generate_lm)
[Writing in progress]