pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

CUDA assumption in the ts_compile code #941

Open bwasti opened 2 years ago

bwasti commented 2 years ago

Hey folks, I stumbled into a CUDA assumption (on my non-CUDA machine).

Here's the fix for me, but it's obviously not very general:

```diff
diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py
index 10fe42a..83002ac 100644
--- a/functorch/_src/compilers.py
+++ b/functorch/_src/compilers.py
@@ -60,7 +60,7 @@ def ts_compile(fx_g: fx.GraphModule, _) -> Callable:
     for i in range(1000):
         attr = f"_tensor_constant{i}"
         if hasattr(fx_g, attr):
-            setattr(fx_g, attr, getattr(fx_g, attr).cuda())
+            setattr(fx_g, attr, getattr(fx_g, attr))
         else:
             break
```

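For context, a device-agnostic alternative would be to move the lifted constants to a caller-supplied device rather than hardcoding `.cuda()`. The sketch below illustrates this; `move_constants_to_device` is a hypothetical helper, not the fix that landed upstream, and it assumes the `_tensor_constant{i}` naming that `torch.fx` uses for tensors lifted out of a traced `forward()`.

```python
import torch
import torch.fx as fx

def move_constants_to_device(fx_g: fx.GraphModule, device: torch.device) -> None:
    # Hypothetical helper: a device-agnostic sketch of the loop in
    # ts_compile, replacing the hardcoded .cuda() with .to(device).
    for i in range(1000):
        attr = f"_tensor_constant{i}"
        if not hasattr(fx_g, attr):
            break  # constants are numbered contiguously, so stop at the first gap
        setattr(fx_g, attr, getattr(fx_g, attr).to(device))

# Demo: symbolic_trace lifts tensors created inside forward() into
# `_tensor_constant{i}` attributes on the resulting GraphModule.
class AddConst(torch.nn.Module):
    def forward(self, x):
        return x + torch.tensor([1.0, 2.0])

gm = fx.symbolic_trace(AddConst())
move_constants_to_device(gm, torch.device("cpu"))  # works without CUDA
```

Passing the device of the graph's example inputs (instead of assuming CUDA) would keep the original nvfuser workaround functional on CPU-only machines.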
vfdev-5 commented 2 years ago

@bwasti where did you encounter that? The code on the main branch is a bit different: https://github.com/pytorch/functorch/blob/76178743084277cf6d7cac752279f905cdd60e13/functorch/_src/compilers.py#L24-L62

Chillee commented 2 years ago

Yeah, it's removed on main, I believe - this was a hack we used to work around nvfuser limitations, but that's been fixed now.