rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Multi-GPU Question #18

Open adam-hartshorne opened 4 days ago

adam-hartshorne commented 4 days ago

I notice you have a stable branch for multi-GPU testing. I was just wondering whether torch2jax actually works out of the box with what I believe is now the standard JAX multi-GPU paradigm of sharding, i.e.

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
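
(For reference, the paradigm in those notebooks looks roughly like the plain-JAX sketch below, with no torch2jax involved: arrays are placed on a device mesh and `jit` propagates the sharding. This is only an illustration of the linked docs, not of torch2jax itself.)

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over all available devices and shard the leading axis.
mesh = Mesh(jax.devices(), axis_names=("batch",))
sharding = NamedSharding(mesh, P("batch"))

# Place a global array across devices (assumes the device count divides axis 0).
x = jnp.arange(32.0).reshape(8, 4)
x = jax.device_put(x, sharding)

# A jit-compiled computation follows the sharding of its inputs.
y = jax.jit(lambda a: jnp.sin(a) * 2.0)(x)
print(y.sharding)
```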

rdyro commented 3 days ago

The multi-GPU support there refers to correctly calling functions on a particular GPU, but unfortunately not to multi-GPU sharded arrays (the paradigm from JAX). I have to learn more about sharding in torch to figure out how to support a function over sharded arrays in torch.

I'm currently using the NVIDIA C++ functionality for detecting which GPU the data is on, so as long as torch2jax is called from within shard_map (exactly!), it should hopefully work correctly. I'm planning to test this in the coming days. (I'll leave the issue open until I can test it.)
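
(A minimal sketch of what calling torch2jax from inside shard_map could look like, assuming the `torch2jax(torch_fn, *example_args)` wrapper shown in the README and a per-shard example input. This follows the untested plan described above, so treat it as a sketch rather than a confirmed pattern.)

```python
from functools import partial

import jax
import jax.numpy as jnp
import torch
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from torch2jax import torch2jax  # wrapper entry point, as in the project README


def torch_fn(x):
    # Plain PyTorch; should execute on whichever GPU holds the local shard.
    return torch.sin(x) * 2.0


n_dev = len(jax.devices())
mesh = Mesh(jax.devices(), axis_names=("batch",))

# Build the wrapper against a *per-shard* example input, since inside
# shard_map each call only ever sees its local shard.
local_shape = (4, 3)
wrapped = torch2jax(torch_fn, torch.zeros(local_shape))


@partial(shard_map, mesh=mesh, in_specs=P("batch"), out_specs=P("batch"))
def sharded_call(x_local):
    return wrapped(x_local)


x = jnp.ones((4 * n_dev, 3))  # global array, split along axis 0 by shard_map
y = jax.jit(sharded_call)(x)
```

The point of the sketch is that the wrapper is built for the local shard shape, so each per-device call hands PyTorch an array already resident on that device, matching the per-GPU detection described above.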

Tangentially, this weekend, I finished porting torch2jax (in the new-ffi branch) to the new FFI interface, so long-term support should be assured now.

adam-hartshorne commented 3 days ago

Great job on getting the FFI interface working. I just tried installing from that branch and doing a fresh recompile on one of my use cases, and everything seems to work seamlessly.