Thanks for catching this!
The package was unfortunately never tested on a multi-GPU setup; it's a current TODO.
Thanks for contributing this minimal example; I should be able to make at least this simple test case work soon.
Not sure if that's helpful, but as a temporary workaround you could try a PyTorch wrapper function that takes all inputs on GPU 0 and returns outputs on GPU 0 as well (even if the computation is done elsewhere).
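Something along these lines (a rough sketch, with a toy model standing in for your network):

```python
import torch

# toy stand-in for your network, living on GPU 1
model = torch.nn.Linear(16, 10).to("cuda:1")

def wrapped_model(x: torch.Tensor) -> torch.Tensor:
    # torch2jax hands over tensors on cuda:0; move them to the model's device,
    # compute there, and return the result on cuda:0 again
    return model(x.to("cuda:1")).to("cuda:0")
```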
@zwei-beiner Can you install the version of `torch2jax` from this branch: https://github.com/rdyro/torch2jax/tree/multi_gpu_experimental
I was able to make your script run with those changes to `torch2jax`, but please let me know if you run into any trouble.
Wow, thanks a lot for the quick follow-up! I will test the new version tomorrow.
Just for your information, what I'm ultimately trying to do is call this with `jax.pmap`:

```python
# Call `torch2jax` once. Then split `inputs` over GPUs and combine results.
jax.pmap(jax_loss)(inputs)
```
Not sure if I should expect this to work since the external callback is "stateful", i.e. the neural net has some parameters in it which have to be copied onto each device.
At the moment, my plan is to make multiple copies of the neural net on each device "manually" (i.e. keep a list of functions, each compiled with `torch2jax`) and then select a function depending on which device the input is located on.
Do you think just using `pmap` would work automatically, or would I have to do it the above manual way?
Just ran the script with the new branch and I am still getting the error. Here's the full stack trace:
Traceback (most recent call last):
File "new_branch.py", line 37, in <module>
print(jax_loss(array)) # Throws error
^^^^^^^^^^^^^^^
File " venv/lib64/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
ret = wrapped_fn_flat(*tree_flatten(args)[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " venv/lib64/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
return torch_prim.bind(*args)
^^^^^^^^^^^^^^^^^^^^^^
File " venv/lib64/python3.11/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " venv/lib64/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " venv/lib64/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File " venv/lib64/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
RuntimeError: Specified device cuda:0 does not match device of data cuda:1
Exception raised from make_tensor at aten/src/ATen/Functions.cpp:24 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7eff99cb54d7 in venv/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x11a22dc (0x7effc0fe32dc in venv/lib64/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #2: apply_torch_call(void**, DynamicTorchCallDescriptor const&) + 0x21f (0x7eff675ebb9f in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #3: gpu_apply_torch_call(CUstream_st*, void**, char const*, unsigned long) + 0x7a (0x7eff675f0fca in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #4: <unknown function> + 0x45faaff (0x7f001b4b4aff in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #5: <unknown function> + 0x45fb4f5 (0x7f001b4b54f5 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #6: <unknown function> + 0x554a3e3 (0x7f001c4043e3 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #7: <unknown function> + 0x5547b89 (0x7f001c401b89 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #8: <unknown function> + 0x5546244 (0x7f001c400244 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #9: <unknown function> + 0x5545b8f (0x7f001c3ffb8f in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x75ceea6 (0x7f001e488ea6 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x11f7735 (0x7f00180b1735 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x11f8088 (0x7f00180b2088 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x11664c5 (0x7f00180204c5 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x1168c84 (0x7f0018022c84 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x116b1c0 (0x7f00180251c0 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x1026341 (0x7f0017ee0341 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0xfe2ec9 (0x7f0017e9cec9 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0xfe42ec (0x7f0017e9e2ec in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x685ce5 (0x7f001753fce5 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0x685b1d (0x7f001753fb1d in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #21: <unknown function> + 0x10c7c5c (0x7f0017f81c5c in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #36: <unknown function> + 0x6e169f (0x7f001759b69f in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #37: <unknown function> + 0x6dff43 (0x7f0017599f43 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #55: <unknown function> + 0x6e169f (0x7f001759b69f in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #56: <unknown function> + 0x6dff43 (0x7f0017599f43 in venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
Ran it in a fresh virtual environment. Getting a different error this time:
RuntimeError: Unknown device: -3. If you have recently updated the caffe2.proto file to add a new device type, did you forget to update the DeviceTypeName() function to reflect such recent changes?
Full stack trace:
Traceback (most recent call last):
File "new_branch.py", line 37, in <module>
print(jax_loss(array)) # Throws error
^^^^^^^^^^^^^^^
File "venv_fresh/lib64/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
ret = wrapped_fn_flat(*tree_flatten(args)[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh/lib64/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
return torch_prim.bind(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh/lib64/python3.11/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh/lib64/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh/lib64/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh/lib64/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
RuntimeError: Unknown device: -3. If you have recently updated the caffe2.proto file to add a new device type, did you forget to update the DeviceTypeName() function to reflect such recent changes?
Exception raised from DeviceTypeName at ../c10/core/DeviceType.cpp:55 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f92f8790d87 in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f92f874175f in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #2: c10::DeviceTypeName(c10::DeviceType, bool) + 0x3a9 (0x7f92f874b6d9 in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #3: c10::Device::str() const + 0x1f (0x7f92f87498bf in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #4: c10::operator<<(std::ostream&, c10::Device const&) + 0x13 (0x7f92f87499e3 in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #5: <unknown function> + 0x1e86a2b (0x7f932cd38a2b in venv_fresh/lib64/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x116fa6f (0x7f932c021a6f in venv_fresh/lib64/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #7: apply_torch_call(void**, DynamicTorchCallDescriptor const&) + 0x21f (0x7f9283032b9f in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #8: gpu_apply_torch_call(CUstream_st*, void**, char const*, unsigned long) + 0x7a (0x7f9283037fca in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #9: <unknown function> + 0x45faaff (0x7f934e845aff in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x45fb4f5 (0x7f934e8464f5 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x554a3e3 (0x7f934f7953e3 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x5547b89 (0x7f934f792b89 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x5546244 (0x7f934f791244 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x5545b8f (0x7f934f790b8f in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x75ceea6 (0x7f9351819ea6 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x11f7735 (0x7f934b442735 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0x11f8088 (0x7f934b443088 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0x11664c5 (0x7f934b3b14c5 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x1168c84 (0x7f934b3b3c84 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0x116b1c0 (0x7f934b3b61c0 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #21: <unknown function> + 0x1026341 (0x7f934b271341 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #22: <unknown function> + 0xfe2ec9 (0x7f934b22dec9 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #23: <unknown function> + 0xfe42ec (0x7f934b22f2ec in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #24: <unknown function> + 0x685ce5 (0x7f934a8d0ce5 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #25: <unknown function> + 0x685b1d (0x7f934a8d0b1d in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #26: <unknown function> + 0x10c7c5c (0x7f934b312c5c in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #41: <unknown function> + 0x6e169f (0x7f934a92c69f in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #42: <unknown function> + 0x6dff43 (0x7f934a92af43 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #60: <unknown function> + 0x6e169f (0x7f934a92c69f in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #61: <unknown function> + 0x6dff43 (0x7f934a92af43 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
Possibly related to https://github.com/pytorch/pytorch/issues/121308
Ran this in another fresh venv, using GCC 11 this time (the above stack traces were produced with a Python built with GCC 8.5, and PyTorch requires at least GCC 9).
Sometimes I get the error
RuntimeError: Unknown layout
Exception raised from operator<< at ../c10/core/Layout.h:69 (most recent call first):
while at other times I get
RuntimeError: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
The error alternates randomly between the above two when I rerun the script multiple times.
Here's the full stack trace for the first error:
Traceback (most recent call last):
File "new_branch.py", line 37, in <module>
print(jax_loss(array)) # Throws error
^^^^^^^^^^^^^^^
File "venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
ret = wrapped_fn_flat(*tree_flatten(args)[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
return torch_prim.bind(*args)
^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh2/lib/python3.11/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh2/lib/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh2/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "venv_fresh2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
^^^^^^^^^^
RuntimeError: Unknown layout
Exception raised from operator<< at ../c10/core/Layout.h:69 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f7e4d809d87 in venv_fresh2/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f7e4d7ba828 in venv_fresh2/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x183589a (0x7f7e8176089a in venv_fresh2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x183b26a (0x7f7e8176626a in venv_fresh2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #4: at::TensorMaker::make_tensor() + 0x66a (0x7f7e822e458a in venv_fresh2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #5: apply_torch_call(void**, DynamicTorchCallDescriptor const&) + 0x21f (0x7f7de00cbb9f in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #6: gpu_apply_torch_call(CUstream_st*, void**, char const*, unsigned long) + 0x7a (0x7f7de00d0fca in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #7: <unknown function> + 0x45faaff (0x7f7ea3920aff in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #8: <unknown function> + 0x45fb4f5 (0x7f7ea39214f5 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #9: <unknown function> + 0x554a3e3 (0x7f7ea48703e3 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x5547b89 (0x7f7ea486db89 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x5546244 (0x7f7ea486c244 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x5545b8f (0x7f7ea486bb8f in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x75ceea6 (0x7f7ea68f4ea6 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x11f7735 (0x7f7ea051d735 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x11f8088 (0x7f7ea051e088 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x11664c5 (0x7f7ea048c4c5 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0x1168c84 (0x7f7ea048ec84 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0x116b1c0 (0x7f7ea04911c0 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x1026341 (0x7f7ea034c341 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0xfe2ec9 (0x7f7ea0308ec9 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #21: <unknown function> + 0xfe42ec (0x7f7ea030a2ec in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #22: <unknown function> + 0x685ce5 (0x7f7e9f9abce5 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #23: <unknown function> + 0x685b1d (0x7f7e9f9abb1d in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #24: <unknown function> + 0x10c7c5c (0x7f7ea03edc5c in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #40: <unknown function> + 0x6e169f (0x7f7e9fa0769f in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #41: <unknown function> + 0x6dff43 (0x7f7e9fa05f43 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #56: <unknown function> + 0x6e169f (0x7f7e9fa0769f in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #57: <unknown function> + 0x6dff43 (0x7f7e9fa05f43 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
> Do you think just using `pmap` would work automatically, or would I have to do it the above manual way?
`pmap` might be hard to support automatically because it'd need to somehow be able to communicate with PyTorch. However, your manual PyTorch logic sounds like a very good solution. You can either define multiple `torch2jax` functions or do the model choice via Python logic directly inside the wrapped function.
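For the second option, here's a rough sketch of what I mean (toy model; it assumes one replica of the network per visible GPU):

```python
import torch

# one copy of a (toy) network per visible GPU
replicas = {
    i: torch.nn.Linear(16, 10).to(f"cuda:{i}")
    for i in range(torch.cuda.device_count())
}

def torch_fn(x: torch.Tensor) -> torch.Tensor:
    # dispatch to the replica that already lives on the input's device,
    # so no cross-device copies happen inside the call
    return replicas[x.device.index](x)
```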
Currently, `torch2jax` makes two assumptions about input and output tensors:
1. they are contiguous in memory
2. they are placed on a single device (they are not sharded across multiple devices)
I expect PyTorch might break 1. sometimes, but you should generally be able to call `.contiguous()` on the output tensors in your PyTorch function to fix it.
I expect JAX might sometimes break 2., especially when trying to use something like `pmap`. You'd need to make sure input tensors are not sharded. I'm not yet sure what the fix is there, but I'm going to look into it soon.
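For illustration, rough sketches of both mitigations (shapes and the device index are placeholders):

```python
import jax
import jax.numpy as jnp
import torch

def torch_fn(x: torch.Tensor) -> torch.Tensor:
    # a transpose returns a non-contiguous view; .contiguous() restores
    # assumption 1 before the result is handed back to JAX
    return x.t().contiguous()

# assumption 2: explicitly place the JAX input on one device rather than
# letting it end up sharded across devices
x = jax.device_put(jnp.ones((4, 8)), jax.devices("gpu")[1])
```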
If you have a small testing script, I'd love to take a look!
The above errors were still produced with the original testing script (no modifications at all) with the new branch.
> - they are NOT placed on a single device (they are sharded tensors with multiple devices)

But the script does not shard the array (it is placed entirely on GPU 1), and it is working for you, if I understand correctly?
Yes, it works for me. To give you a bit of detail, I'm using a 2 x 4090 setup, CUDA version 12.3, Python 3.11.
Can you just confirm that you're recompiling the cpp extension by deleting `~/.cache/torch2jax` (if you're on UNIX)?
Ah, I wasn't doing that. However, I'm getting the following error now:
RuntimeError: Error building extension 'torch2jax_cpp'
...
FAILED: gpu_impl.cuda.o
...
/usr/local/software/cuda/12.1/include/crt/sm_80_rt.h(109): error: more than one instance of overloaded function "__nv_associate_access_property_impl" has "C" linkage
(On that note, would it be possible to expose a `force_recompile` option in the top-level `torch2jax` function, as I didn't know the cached files were located there in the first place?)
Ok, thanks for all the feedback!
Your CUDA error is likely due to my incorrect placement of some CUDA utility code; I have fixed this in the latest commit on the multi_gpu_experimental branch.
I added explicit versioning to the cpp extension, so that a new `torch2jax` version will automatically trigger recompilation.
Explicit recompilation is currently supported, but somewhat hidden, because the user should probably never have to do it (especially now that the cpp extension is versioned, something I just added). But you can do it from Python:
```python
from torch2jax.compile import compile_extension

compile_extension(force_recompile=True)
```
Let me know if the new extension CUDA code compiles without an error. I tried testing it on my setup but couldn't reproduce the error (even though your setup, I think, correctly flags the previous wrong placement of CUDA code).
Thanks for introducing the cpp versioning!
Unfortunately the error still persists. Here's a full stack trace:
/venv_fresh2/lib/python3.11/site-packages/torch/utils/_device.py:77: UserWarning: Using a target size (torch.Size([1, 10])) that is different to the input size (torch.Size([10])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return func(*args, **kwargs)
Cache empty, we will compile the C++ extension component now...
Traceback (most recent call last):
File venv_fresh2/lib/python3.11/site-packages/torch2jax/compile.py", line 52, in compile_extension
mod = import_module("torch2jax_cpp")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/lib/python3.11/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
File "<frozen importlib._bootstrap>", line 1140, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'torch2jax_cpp'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2096, in _run_ninja_build
subprocess.run(
File "/usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/lib/python3.11/subprocess.py", line 571, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File new_branch.py", line 31, in <module>
jax_loss = torch2jax(
^^^^^^^^^^
File venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 232, in torch2jax
wrapped_fn_flat = torch2jax_flat(
^^^^^^^^^^^^^^^
File venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 44, in torch2jax_flat
cpp_module = compile_and_import_module()
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File venv_fresh2/lib/python3.11/site-packages/torch2jax/compile.py", line 93, in compile_and_import_module
compile_extension(force_recompile)
File venv_fresh2/lib/python3.11/site-packages/torch2jax/compile.py", line 65, in compile_extension
mod = cpp_extension.load(
^^^^^^^^^^^^^^^^^^^
File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1306, in load
return _jit_compile(
^^^^^^^^^^^^^
File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1710, in _jit_compile
_write_ninja_file_and_build_library(
File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1823, in _write_ninja_file_and_build_library
_run_ninja_build(
File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2112, in _run_ninja_build
raise RuntimeError(message) from e
RuntimeError: Error building extension 'torch2jax_cpp': [1/3] /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output main.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/main.cu -o main.cuda.o
FAILED: main.cuda.o
/usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output main.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/main.cu -o main.cuda.o
/usr/local/software/cuda/12.1/include/crt/sm_80_rt.h(109): error: more than one instance of overloaded function "__nv_associate_access_property_impl" has "C" linkage
1 error detected in the compilation of venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/main.cu".
[2/3] /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output gpu_impl.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/gpu_impl.cu -o gpu_impl.cuda.o
FAILED: gpu_impl.cuda.o
/usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output gpu_impl.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/gpu_impl.cu -o gpu_impl.cuda.o
/usr/local/software/cuda/12.1/include/crt/sm_80_rt.h(109): error: more than one instance of overloaded function "__nv_associate_access_property_impl" has "C" linkage
1 error detected in the compilation of venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/gpu_impl.cu".
ninja: build stopped: subcommand failed.
From the trace, it might be that PyTorch is finding the wrong CUDA compiler on your system: CUDA 11.4 instead of 12.1.
In those cases PyTorch advises setting the environment variable `CUDA_HOME` explicitly, either via `export CUDA_HOME=...` or from Python via `import os; os.environ["CUDA_HOME"] = "..."` before any other imports.
In your case, the correct CUDA home is probably `export CUDA_HOME="/usr/local/software/cuda/12.1"`.
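Concretely, a minimal sketch of doing this from Python before any other imports (the path is taken from your trace; adjust it if needed):

```python
import os

# must be set before torch / torch2jax are imported
os.environ["CUDA_HOME"] = "/usr/local/software/cuda/12.1"

from torch2jax.compile import compile_extension

# rebuild the cpp extension against the newly selected CUDA toolkit
compile_extension(force_recompile=True)
```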
Let me know if the compilation succeeds with that change; if it's not that, I'll try to dig into it deeper sometime soon.
Trying `export CUDA_HOME="/usr/local/software/cuda/12.1"` (thanks for giving me the directory) changed the error messages, so this is likely at least part of the problem. The new error messages state that GCC versions > 10 are unsupported (GCC 11.2 is installed).
It seems quite likely that there's some dependency problem between CUDA, GCC, Python, and PyTorch on the machine I'm using, especially given that it's working on your machine. As I don't have enough privileges to try different installations, it seems to be a bit of a dead end here.
Perhaps you could merge at least the cpp versioning and the `with torch.no_grad()` change into the main branch.
In any case, the "temporary" solution of moving things between GPUs on the PyTorch side works for a multi-GPU setup.
Hmmm, it's unfortunate that you're still getting compilation errors. I wish I could offer more help, but in `torch2jax` I'm almost entirely deferring to the PyTorch extension compilation mechanism. Perhaps they will introduce more control over CUDA loading soon.
I'm glad the temporary solution is working; I'll keep thinking about how JAX sharding/device mapping could be used to inform the PyTorch call.
I'll merge all the code changes your feedback and suggestions contributed later today; I was hoping to first confirm that you can successfully compile on your side.
I'll keep this issue open until tomorrow unless you have any other questions. If anything else comes up later feel free to comment here or open another issue!
Ok, I'll run some final tests and let you know if things are working hopefully later today.
I merged the changes your feedback contributed into `main`. Let me know if you are still experiencing issues.
I published a new PyPI version of this package here that includes those changes: https://pypi.org/project/wrap-torch2jax/
Closing; the first version of multi-GPU support should work!
Hi, I'm getting an error when trying to run `torch2jax` on multiple GPUs. For some reason a part of the model ends up on GPU 0 after using `torch2jax`, despite being moved to GPU 1.
Here's a reproducer using 2 x A100 GPUs:
Error:
Setup: