rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Multi-gpu support #12

Closed · zwei-beiner closed this issue 6 months ago

zwei-beiner commented 6 months ago

Hi, I'm getting an error when trying to run torch2jax on multiple GPUs. For some reason, part of the model ends up on GPU 0 after using torch2jax, despite being moved to GPU 1.

Here's a reproducer using 2 x A100 GPUs:

import jax
import jax.numpy as jnp
import torch
import torch.nn as nn
from torch2jax import torch2jax

torch.set_default_device('cuda:1') # Move everything to gpu 1

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(32, 10)
    def forward(self, x):
        x = self.fc(x)
        return x

net = Net().to('cuda:1') # Move explicitly to gpu 1
def loss(input):
    output = net(input)
    target = torch.zeros(10) 
    target = target.view(1, -1) 
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    return loss

input = torch.ones(32).to('cuda:1') # Input on gpu 1
result = loss(input) # Works
assert result.device.index == 1 # Output is on gpu 1

# Try the above but in jax
jax_loss = torch2jax(
    loss,
    input
)
array = jax.device_put(jnp.ones((32)), jax.devices()[1]) # jax array on gpu 1
print(array.devices()) # Prints {cuda(id=1)}
print(jax_loss(array)) # Throws error

Error:

Traceback (most recent call last):
  File "bug.py", line 36, in <module>
    print(jax_loss(array)) # Throws an error
          ^^^^^^^^^^^^^^^
  File "venv/lib64/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
    ret = wrapped_fn_flat(*tree_flatten(args)[0])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib64/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
    return torch_prim.bind(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib64/python3.11/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib64/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib64/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv/lib64/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
RuntimeError: Specified device cuda:0 does not match device of data cuda:1

Setup:

Python 3.11.5
torch 2.0.1
jax 0.4.26
jaxlib 0.4.26+cuda12.cudnn89
rdyro commented 6 months ago

Thanks for catching this!

Unfortunately, the package was never tested on a multi-GPU setup; it's a current TODO.

Thanks for contributing this minimal example; I should be able to make at least this simple test case work soon.

Not sure if that's helpful, but as a temporary workaround you could try a PyTorch wrapper function that takes all inputs on GPU 0 and returns all outputs on GPU 0 as well (even if the computation is done elsewhere).
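
For concreteness, a rough sketch of such a wrapper (untested; it reuses the loss function from your reproducer above):

import torch
from torch2jax import torch2jax

def loss_on_gpu0(x):
    out = loss(x.to("cuda:1"))  # run the actual computation on cuda:1
    return out.to("cuda:0")     # hand the result back to JAX on cuda:0

jax_loss = torch2jax(loss_on_gpu0, torch.ones(32, device="cuda:0"))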

rdyro commented 6 months ago

@zwei-beiner Can you install the version of torch2jax from this branch: https://github.com/rdyro/torch2jax/tree/multi_gpu_experimental

I was able to make your script run with those changes to torch2jax, but please let me know if you run into any trouble.

zwei-beiner commented 6 months ago

Wow, thanks a lot for the quick follow-up! I will test the new version tomorrow.

Just for your information, what I'm ultimately trying to do is to call this with jax.pmap:

# Call `torch2jax` once. Then split `inputs` over GPUs and combine results.
jax.pmap(jax_loss)(inputs)

Not sure if I should expect this to work since the external callback is "stateful", i.e. the neural net has some parameters in it which have to be copied onto each device.

At the moment, my plan is to make multiple copies of the neural net on each device "manually" (i.e. keep a list of functions, each compiled with torch2jax) and then select a function depending on which device the input is located on.
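
Roughly something like this (untested sketch, reusing the Net/loss setup from my reproducer above):

import copy
import torch
from torch2jax import torch2jax

# one torch2jax-wrapped loss per GPU, each with its own copy of the model
wrapped_losses = []
for i in range(torch.cuda.device_count()):
    net_i = copy.deepcopy(net).to(f"cuda:{i}")

    def loss_i(x, _net=net_i, _dev=f"cuda:{i}"):
        target = torch.zeros(10, device=_dev)
        return torch.nn.functional.mse_loss(_net(x), target)

    wrapped_losses.append(torch2jax(loss_i, torch.ones(32, device=f"cuda:{i}")))

def dispatch_loss(array):
    # select the wrapped function matching the device the jax array lives on
    device = next(iter(array.devices()))
    return wrapped_losses[device.id](array)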

Do you think just using pmap would work automatically, or would I have to do it the above manual way?

zwei-beiner commented 6 months ago

Just ran the script with the new branch and I am still getting the error. Here's the full stack trace:

Traceback (most recent call last):
  File "new_branch.py", line 37, in <module>
    print(jax_loss(array)) # Throws error
          ^^^^^^^^^^^^^^^
  File " venv/lib64/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
    ret = wrapped_fn_flat(*tree_flatten(args)[0])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " venv/lib64/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
    return torch_prim.bind(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File " venv/lib64/python3.11/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " venv/lib64/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " venv/lib64/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File " venv/lib64/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
RuntimeError: Specified device cuda:0 does not match device of data cuda:1
Exception raised from make_tensor at aten/src/ATen/Functions.cpp:24 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7eff99cb54d7 in  venv/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x11a22dc (0x7effc0fe32dc in  venv/lib64/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #2: apply_torch_call(void**, DynamicTorchCallDescriptor const&) + 0x21f (0x7eff675ebb9f in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #3: gpu_apply_torch_call(CUstream_st*, void**, char const*, unsigned long) + 0x7a (0x7eff675f0fca in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #4: <unknown function> + 0x45faaff (0x7f001b4b4aff in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #5: <unknown function> + 0x45fb4f5 (0x7f001b4b54f5 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #6: <unknown function> + 0x554a3e3 (0x7f001c4043e3 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #7: <unknown function> + 0x5547b89 (0x7f001c401b89 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #8: <unknown function> + 0x5546244 (0x7f001c400244 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #9: <unknown function> + 0x5545b8f (0x7f001c3ffb8f in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x75ceea6 (0x7f001e488ea6 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x11f7735 (0x7f00180b1735 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x11f8088 (0x7f00180b2088 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x11664c5 (0x7f00180204c5 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x1168c84 (0x7f0018022c84 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x116b1c0 (0x7f00180251c0 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x1026341 (0x7f0017ee0341 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0xfe2ec9 (0x7f0017e9cec9 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0xfe42ec (0x7f0017e9e2ec in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x685ce5 (0x7f001753fce5 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0x685b1d (0x7f001753fb1d in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #21: <unknown function> + 0x10c7c5c (0x7f0017f81c5c in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #36: <unknown function> + 0x6e169f (0x7f001759b69f in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #37: <unknown function> + 0x6dff43 (0x7f0017599f43 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #55: <unknown function> + 0x6e169f (0x7f001759b69f in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #56: <unknown function> + 0x6dff43 (0x7f0017599f43 in  venv/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
zwei-beiner commented 6 months ago

Ran it in a fresh virtual environment. Getting a different error this time:

RuntimeError: Unknown device: -3. If you have recently updated the caffe2.proto file to add a new device type, did you forget to update the DeviceTypeName() function to reflect such recent changes?

Full stack trace:

Traceback (most recent call last):
  File "new_branch.py", line 37, in <module>
    print(jax_loss(array)) # Throws error
          ^^^^^^^^^^^^^^^
  File "venv_fresh/lib64/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
    ret = wrapped_fn_flat(*tree_flatten(args)[0])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh/lib64/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
    return torch_prim.bind(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh/lib64/python3.11/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh/lib64/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh/lib64/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh/lib64/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
RuntimeError: Unknown device: -3. If you have recently updated the caffe2.proto file to add a new device type, did you forget to update the DeviceTypeName() function to reflect such recent changes?
Exception raised from DeviceTypeName at ../c10/core/DeviceType.cpp:55 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f92f8790d87 in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f92f874175f in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #2: c10::DeviceTypeName(c10::DeviceType, bool) + 0x3a9 (0x7f92f874b6d9 in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #3: c10::Device::str() const + 0x1f (0x7f92f87498bf in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #4: c10::operator<<(std::ostream&, c10::Device const&) + 0x13 (0x7f92f87499e3 in venv_fresh/lib64/python3.11/site-packages/torch/lib/libc10.so)
frame #5: <unknown function> + 0x1e86a2b (0x7f932cd38a2b in venv_fresh/lib64/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x116fa6f (0x7f932c021a6f in venv_fresh/lib64/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #7: apply_torch_call(void**, DynamicTorchCallDescriptor const&) + 0x21f (0x7f9283032b9f in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #8: gpu_apply_torch_call(CUstream_st*, void**, char const*, unsigned long) + 0x7a (0x7f9283037fca in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #9: <unknown function> + 0x45faaff (0x7f934e845aff in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x45fb4f5 (0x7f934e8464f5 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x554a3e3 (0x7f934f7953e3 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x5547b89 (0x7f934f792b89 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x5546244 (0x7f934f791244 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x5545b8f (0x7f934f790b8f in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x75ceea6 (0x7f9351819ea6 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x11f7735 (0x7f934b442735 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0x11f8088 (0x7f934b443088 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0x11664c5 (0x7f934b3b14c5 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x1168c84 (0x7f934b3b3c84 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0x116b1c0 (0x7f934b3b61c0 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #21: <unknown function> + 0x1026341 (0x7f934b271341 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #22: <unknown function> + 0xfe2ec9 (0x7f934b22dec9 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #23: <unknown function> + 0xfe42ec (0x7f934b22f2ec in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #24: <unknown function> + 0x685ce5 (0x7f934a8d0ce5 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #25: <unknown function> + 0x685b1d (0x7f934a8d0b1d in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #26: <unknown function> + 0x10c7c5c (0x7f934b312c5c in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #41: <unknown function> + 0x6e169f (0x7f934a92c69f in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #42: <unknown function> + 0x6dff43 (0x7f934a92af43 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #60: <unknown function> + 0x6e169f (0x7f934a92c69f in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
frame #61: <unknown function> + 0x6dff43 (0x7f934a92af43 in venv_fresh/lib64/python3.11/site-packages/jaxlib/xla_extension.so)
zwei-beiner commented 6 months ago

Possibly related to https://github.com/pytorch/pytorch/issues/121308

zwei-beiner commented 6 months ago

Ran this in another fresh venv, using GCC 11 this time (the above stack traces were produced with a Python environment using GCC 8.5, and PyTorch requires at least GCC 9).

Sometimes I get the error

RuntimeError: Unknown layout
Exception raised from operator<< at ../c10/core/Layout.h:69 (most recent call first):

while at other times I get

RuntimeError: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

The error alternates randomly between the above two when I rerun the script multiple times.

Here's the full stack trace for the first error:

Traceback (most recent call last):
  File "new_branch.py", line 37, in <module>
    print(jax_loss(array)) # Throws error
          ^^^^^^^^^^^^^^^
  File "venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 245, in wrapped_fn
    ret = wrapped_fn_flat(*tree_flatten(args)[0])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 97, in wrapped_fn
    return torch_prim.bind(*args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh2/lib/python3.11/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh2/lib/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh2/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "venv_fresh2/lib/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
RuntimeError: Unknown layout
Exception raised from operator<< at ../c10/core/Layout.h:69 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f7e4d809d87 in venv_fresh2/lib/python3.11/site-packages/torch/lib/libc10.so)                                                                                                                                                                            
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f7e4d7ba828 in venv_fresh2/lib/python3.11/site-packages/torch/lib/libc10.so)                                                                                                                                               
frame #2: <unknown function> + 0x183589a (0x7f7e8176089a in venv_fresh2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x183b26a (0x7f7e8176626a in venv_fresh2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #4: at::TensorMaker::make_tensor() + 0x66a (0x7f7e822e458a in venv_fresh2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #5: apply_torch_call(void**, DynamicTorchCallDescriptor const&) + 0x21f (0x7f7de00cbb9f in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #6: gpu_apply_torch_call(CUstream_st*, void**, char const*, unsigned long) + 0x7a (0x7f7de00d0fca in torch2jax/cpython-311-linux-x86_64/torch2jax_cpp.so)
frame #7: <unknown function> + 0x45faaff (0x7f7ea3920aff in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #8: <unknown function> + 0x45fb4f5 (0x7f7ea39214f5 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #9: <unknown function> + 0x554a3e3 (0x7f7ea48703e3 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #10: <unknown function> + 0x5547b89 (0x7f7ea486db89 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #11: <unknown function> + 0x5546244 (0x7f7ea486c244 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #12: <unknown function> + 0x5545b8f (0x7f7ea486bb8f in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #13: <unknown function> + 0x75ceea6 (0x7f7ea68f4ea6 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #14: <unknown function> + 0x11f7735 (0x7f7ea051d735 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #15: <unknown function> + 0x11f8088 (0x7f7ea051e088 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #16: <unknown function> + 0x11664c5 (0x7f7ea048c4c5 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #17: <unknown function> + 0x1168c84 (0x7f7ea048ec84 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #18: <unknown function> + 0x116b1c0 (0x7f7ea04911c0 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #19: <unknown function> + 0x1026341 (0x7f7ea034c341 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #20: <unknown function> + 0xfe2ec9 (0x7f7ea0308ec9 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #21: <unknown function> + 0xfe42ec (0x7f7ea030a2ec in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #22: <unknown function> + 0x685ce5 (0x7f7e9f9abce5 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #23: <unknown function> + 0x685b1d (0x7f7e9f9abb1d in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #24: <unknown function> + 0x10c7c5c (0x7f7ea03edc5c in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
<omitting python frames>
frame #40: <unknown function> + 0x6e169f (0x7f7e9fa0769f in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #41: <unknown function> + 0x6dff43 (0x7f7e9fa05f43 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #56: <unknown function> + 0x6e169f (0x7f7e9fa0769f in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
frame #57: <unknown function> + 0x6dff43 (0x7f7e9fa05f43 in venv_fresh2/lib/python3.11/site-packages/jaxlib/xla_extension.so)
rdyro commented 6 months ago

> Wow, thanks a lot for the quick follow-up! I will test the new version tomorrow.
>
> Just for your information, what I'm ultimately trying to do is to call this with jax.pmap:
>
> # Call `torch2jax` once. Then split `inputs` over GPUs and combine results.
> jax.pmap(jax_loss)(inputs)
>
> Not sure if I should expect this to work since the external callback is "stateful", i.e. the neural net has some parameters in it which have to be copied onto each device.
>
> At the moment, my plan is to make multiple copies of the neural net on each device "manually" (i.e. keep a list of functions, each compiled with torch2jax) and then select a function depending on which device the input is located on.
>
> Do you think just using pmap would work automatically, or would I have to do it the above manual way?

pmap might be hard to support automatically because it would need some way to communicate with PyTorch. However, your manual PyTorch logic sounds like a very good solution: you can either define multiple torch2jax functions, or probably do the model choice via Python logic directly inside the wrapped function.
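
For the second option, a rough (untested) sketch would be to keep a per-device copy of the model and choose inside a single wrapped function:

import copy
import torch

# hypothetical sketch: `net` is the module from the reproducer above
nets = {i: copy.deepcopy(net).to(f"cuda:{i}") for i in range(torch.cuda.device_count())}

def loss_any_device(x):
    model = nets[x.device.index]               # the copy living on x's device
    target = torch.zeros(10, device=x.device)
    return torch.nn.functional.mse_loss(model(x), target)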

rdyro commented 6 months ago

Currently, torch2jax makes two assumptions about input and output tensors:

  1. they are contiguously laid out in memory
  2. they are placed on a single device (i.e. they are NOT sharded across multiple devices)

I expect PyTorch might sometimes break 1., but you should generally be able to call .contiguous() on the output tensors in your PyTorch function to fix it.
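
For example (a sketch, not code from this repo):

import torch

def torch_fn(x):
    # x: assumed to be a square 2-D tensor, just for illustration
    y = torch.matmul(x, x).transpose(0, 1)  # transposes return non-contiguous views
    return y.contiguous()                   # return a contiguous copy to torch2jax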

I expect JAX might sometimes break 2., especially when trying to use something like pmap; you'd need to make sure input tensors are not sharded. I'm not yet sure what the fix there is, but I'm going to look into it soon.

rdyro commented 6 months ago

If you have a small testing script, I'd love to take a look!

zwei-beiner commented 6 months ago

The above errors were produced with the original testing script (no modifications at all) running on the new branch.

> 2. they are placed on a single device (i.e. they are NOT sharded across multiple devices)

But the script does not shard the array, since it's entirely on GPU 1, and it is working for you, if I understand correctly?

rdyro commented 6 months ago

Yes, it works for me. To give you a bit of detail, I'm using a 2 x RTX 4090 setup, CUDA version 12.3, Python 3.11.

Can you just confirm that you're recompiling the cpp extension by deleting ~/.cache/torch2jax (if you're on UNIX)?
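
From Python, that amounts to removing the cache directory (a sketch; this just deletes the path mentioned above so the next torch2jax call rebuilds the extension):

import shutil
from pathlib import Path

shutil.rmtree(Path.home() / ".cache" / "torch2jax", ignore_errors=True)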

zwei-beiner commented 6 months ago

Ah, I wasn't doing that. However, I'm getting the following error now:

RuntimeError: Error building extension 'torch2jax_cpp'
...
FAILED: gpu_impl.cuda.o 
...
/usr/local/software/cuda/12.1/include/crt/sm_80_rt.h(109): error: more than one instance of overloaded function "__nv_associate_access_property_impl" has "C" linkage
zwei-beiner commented 6 months ago

(On that note, would it be possible to expose a force_recompile option in the top-level torch2jax function, as I didn't know the cached files were located there in the first place?)

rdyro commented 6 months ago

Ok, thanks for all the feedback!

Your CUDA error is likely due to my incorrect placement of some CUDA utility code; I have fixed this in the latest commit on the multi_gpu_experimental branch.

I added explicit versioning to the cpp extension, so that a new torch2jax version will automatically trigger recompilation.

Explicit recompilation is currently supported, but somewhat hidden, because the user should probably never have to do it (especially now that the cpp extension is versioned, something I just added). You can trigger it from Python:

from torch2jax.compile import compile_extension
compile_extension(force_recompile=True)

Let me know if the new extension CUDA code compiles without an error. I tried testing it on my setup but couldn't reproduce the error (even though I think your setup correctly flags the incorrect placement of CUDA code).

zwei-beiner commented 6 months ago

Thanks for introducing the cpp versioning!

Unfortunately the error still persists. Here's a full stack trace:

/venv_fresh2/lib/python3.11/site-packages/torch/utils/_device.py:77: UserWarning: Using a target size (torch.Size([1, 10])) that is different to the input size (torch.Size([10])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return func(*args, **kwargs)
Cache empty, we will compile the C++ extension component now...
Traceback (most recent call last):
  File venv_fresh2/lib/python3.11/site-packages/torch2jax/compile.py", line 52, in compile_extension
    mod = import_module("torch2jax_cpp")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
  File "<frozen importlib._bootstrap>", line 1140, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'torch2jax_cpp'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2096, in _run_ninja_build
    subprocess.run(
  File "/usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File new_branch.py", line 31, in <module>
    jax_loss = torch2jax(
               ^^^^^^^^^^
  File venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 232, in torch2jax
    wrapped_fn_flat = torch2jax_flat(
                      ^^^^^^^^^^^^^^^
  File venv_fresh2/lib/python3.11/site-packages/torch2jax/api.py", line 44, in torch2jax_flat
    cpp_module = compile_and_import_module()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File venv_fresh2/lib/python3.11/site-packages/torch2jax/compile.py", line 93, in compile_and_import_module
    compile_extension(force_recompile)
  File venv_fresh2/lib/python3.11/site-packages/torch2jax/compile.py", line 65, in compile_extension
    mod = cpp_extension.load(
          ^^^^^^^^^^^^^^^^^^^
  File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1306, in load
    return _jit_compile(
           ^^^^^^^^^^^^^
  File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1710, in _jit_compile
    _write_ninja_file_and_build_library(
  File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1823, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File venv_fresh2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2112, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'torch2jax_cpp': [1/3] /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output main.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/main.cu -o main.cuda.o 
FAILED: main.cuda.o 
/usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output main.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/main.cu -o main.cuda.o 
/usr/local/software/cuda/12.1/include/crt/sm_80_rt.h(109): error: more than one instance of overloaded function "__nv_associate_access_property_impl" has "C" linkage

1 error detected in the compilation of venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/main.cu".
[2/3] /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output gpu_impl.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/gpu_impl.cu -o gpu_impl.cuda.o 
FAILED: gpu_impl.cuda.o 
/usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/bin/nvcc --generate-dependencies-with-compile --dependency-output gpu_impl.cuda.o.d -ccbin /usr/local/software/spack/spack-views/rocky8-a100-20221118/gcc-11.3.0/gcc-11.3.0/i4xnp7h53ty3rosv2mjoycl2d6cyjddv/bin/gcc -DTORCH_EXTENSION_NAME=torch2jax_cpp -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/TH -isystem /venv_fresh2/lib/python3.11/site-packages/torch/include/THC -isystem /usr/local/software/spack/spack-rhel8-20210927/opt/spack/linux-centos8-zen2/gcc-9.4.0/cuda-11.4.0-3hnxhjt2jt4ruy75w2q4mnvkw7dty72l/include -isystem /usr/local/software/spack/spack-views/rocky8-a100-20221118/python-3.11.9/gcc-11.3.0/57ayfwmekjj72xepaojw7tke676c7sqd/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -DTORCH2JAX_WITH_CUDA -O3 -std=c++17 -c /venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/gpu_impl.cu -o gpu_impl.cuda.o 
/usr/local/software/cuda/12.1/include/crt/sm_80_rt.h(109): error: more than one instance of overloaded function "__nv_associate_access_property_impl" has "C" linkage

1 error detected in the compilation of venv_fresh2/lib/python3.11/site-packages/torch2jax/cpp/gpu_impl.cu".
ninja: build stopped: subcommand failed.
rdyro commented 6 months ago

From the trace, it might be that your PyTorch is finding the wrong CUDA compiler on your system: CUDA 11.4 instead of 12.1.

In those cases, PyTorch advises setting the environment variable CUDA_HOME explicitly, either in the shell via export CUDA_HOME=..., or from Python via import os; os.environ["CUDA_HOME"] = "..." before any other imports.

In your case, the correct CUDA home is probably export CUDA_HOME="/usr/local/software/cuda/12.1"
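
Or from Python, before any other imports (a sketch using that guessed path; adjust it to your actual CUDA 12.1 installation):

import os
os.environ["CUDA_HOME"] = "/usr/local/software/cuda/12.1"

import torch                      # imported only after CUDA_HOME is set
from torch2jax import torch2jax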

Let me know if the compilation succeeds with that change; if that's not it, I'll try to dig into it deeper sometime soon.

zwei-beiner commented 6 months ago

Trying export CUDA_HOME="/usr/local/software/cuda/12.1" (thanks for giving me the directory) changed the error messages, so this is likely at least part of the problem. The new error messages state that GCC versions newer than 10 are unsupported (GCC 11.2 is installed).

It seems quite likely that there's some dependency problem between CUDA, GCC, Python, and PyTorch on the machine I'm using, especially given that it's working on your machine. As I don't have enough privileges to try different installations, this seems to be a bit of a dead end.

Perhaps you could merge at least the cpp versioning and the with torch.no_grad() change into the main branch.

zwei-beiner commented 6 months ago

In any case, the "temporary" solution of moving things between GPUs on the PyTorch side works for a multi-GPU setup.

rdyro commented 6 months ago

Hmm, it's unfortunate that you're still getting compilation errors. I wish I could offer more help, but torch2jax defers almost entirely to PyTorch's extension compilation mechanism. Perhaps they will introduce more control over CUDA loading soon.

I'm glad the temporary solution is working. I'll keep thinking about how JAX sharding/device mapping could be used to inform the PyTorch call.

I'll merge all the code changes that your feedback and suggestions contributed later today; I was hoping to confirm that you can successfully compile on your side first.

I'll keep this issue open until tomorrow unless you have any other questions. If anything else comes up later feel free to comment here or open another issue!

zwei-beiner commented 6 months ago

Ok, I'll run some final tests and let you know, hopefully later today, whether things are working.

rdyro commented 6 months ago

I merged the changes your feedback contributed into main. Let me know if you are still experiencing issues.

I published a new PyPI version of this package that includes those changes: https://pypi.org/project/wrap-torch2jax/

rdyro commented 6 months ago

Closing; the first version of multi-GPU support should work!