adam-hartshorne opened 4 days ago
The multi-GPU support there refers to correctly calling functions on a particular GPU, but unfortunately not to multi-GPU sharded arrays (the paradigm from JAX). I have to learn more about sharding in torch to think about how to support a sharded-array function in torch.
I'm currently using the NVIDIA C++ functionality for detecting which GPU the data is on, so as long as torch2jax
is called from within shard_map specifically, it should hopefully work correctly. I'm planning on testing this in the coming days. (I'll leave the issue open until I can test it.)
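For context, the calling pattern being discussed looks roughly like the sketch below: shard_map splits the input across devices and invokes the wrapped function once per shard, each on its own GPU, which is why per-device GPU detection on the C++ side can work. `torch_fn` is a hypothetical stand-in for a torch2jax-wrapped callable (substituted with a plain JAX function here so the sketch runs anywhere, even on a single CPU device).

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# 1-D device mesh over all available devices (GPUs in the multi-GPU case)
mesh = Mesh(jax.devices(), axis_names=("gpu",))

def torch_fn(x):
    # Hypothetical stand-in for a torch2jax-wrapped function; inside
    # shard_map, x is the local shard resident on the current device.
    return x * 2.0

# Each device's shard is passed to torch_fn independently
sharded_fn = shard_map(torch_fn, mesh=mesh, in_specs=P("gpu"), out_specs=P("gpu"))

x = jnp.arange(8.0)
print(sharded_fn(x))
```

The key property is that `torch_fn` only ever sees per-device data, so a callee that inspects which GPU its inputs live on gets a consistent answer per invocation.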
Tangentially, this weekend, I finished porting torch2jax
(in the new-ffi branch) to the new FFI interface, so long-term support should be assured now.
Great job on getting the FFI interface working. I just tried installing from that branch and doing a fresh recompile on one of my use cases, and everything seems to work seamlessly.
I notice you have a stable branch for multi-gpu testing. I was just wondering if torch2jax does actually work out of the box when using what I believe is now the standard JAX multi-gpu paradigm of sharding i.e.
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
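For reference, the sharding paradigm in question is the one from the first linked notebook: place an array across devices with a `NamedSharding`, and jitted computations then run on every device holding a shard. A minimal sketch (runs on a single device too, since the mesh just spans whatever `jax.devices()` returns):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Mesh over all available devices; on a multi-GPU host these are the GPUs
mesh = Mesh(jax.devices(), axis_names=("gpu",))
sharding = NamedSharding(mesh, P("gpu"))

# Shard the array along its first axis across the "gpu" mesh axis
x = jax.device_put(jnp.arange(8.0), sharding)

# A jitted computation on x executes on each device's local shard
y = jax.jit(lambda a: a * a)(x)
```

The open question in this issue is whether a torch2jax-wrapped function behaves correctly when handed such a sharded `x` out of the box, as opposed to being invoked shard-by-shard via shard_map.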