Open jglaser opened 5 months ago
Thanks for the fix! May I ask which jax version you are using and how you tested RCCL?
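For reference when reporting versions, here is a minimal sketch that collects the relevant details. It only uses `jax.__version__`, `jaxlib.version.__version__`, `jax.default_backend()` and `jax.devices()`. The helper name `describe_jax_install` is hypothetical. Recent jax releases also provide `jax.print_environment_info()`, which prints a similar summary.

```python
# Hypothetical helper: gather version and device info worth pasting into a bug report.
# Assumes jax and jaxlib are installed; the backend/platform names reported
# depend on how jaxlib was built (CPU, CUDA or ROCm).
import jax
import jaxlib.version


def describe_jax_install():
    """Return the jax/jaxlib versions, the default backend and the visible devices."""
    return {
        "jax": jax.__version__,
        "jaxlib": jaxlib.version.__version__,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in describe_jax_install().items():
        print(f"{key}: {value}")
```

On a GPU build, `backend` should name a GPU platform rather than `cpu`; if it says `cpu`, the GPU plugin or jaxlib build is most likely not being picked up.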
RCCL support was there, but XLA has recently been changing rapidly in many of its collective-related parts. For example, in the following, RCCL-related parts were removed, added back, and changed again within a very short period...
Given the situation, we are monitoring the refactoring, and once it is stable we will add it back all at once.
Hi @i-chaochen, this was tested using a local application that worked with jax 0.4.4, the last version I was able to use successfully with ROCm. I remember having to search for a working combination of jax branches and tensorflow forks; there was no documentation or recommendation about what was current and what was not. After that, jax underwent many changes in the move from tensorflow to openxla and stopped working for a while. The jax and xla versions were the latest HEADs of main at the time of the PR (see the commit id above). Luckily, it looks like most ROCm-related changes have been merged into upstream jax/xla, which makes it much easier to work from the official repositories. As mentioned in the PR, the only thing still missing is a ROCm CI, so that things don't keep breaking...
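As a quick way to exercise the collective path without a full application, a tiny `pmap` all-reduce can serve as a smoke test. This is a hypothetical sketch (`allreduce_smoke_test` is an illustrative name), not the test used in this thread, and it assumes that on a multi-GPU ROCm build XLA lowers the cross-device `psum` to an RCCL all-reduce. On a single device or on CPU it still runs but does not touch RCCL.

```python
# Hypothetical all-reduce smoke test. Assumption: with several ROCm GPUs
# visible, the cross-device psum below is executed via RCCL.
import jax
import jax.numpy as jnp


def allreduce_smoke_test():
    n = jax.local_device_count()
    # One scalar per local device: 0, 1, ..., n-1.
    x = jnp.arange(n, dtype=jnp.float32)
    # psum over the mapped axis sums the scalars held by all devices,
    # so every device should end up with 0 + 1 + ... + (n-1).
    summed = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)
    expected = n * (n - 1) / 2
    return n, summed, expected


if __name__ == "__main__":
    n, summed, expected = allreduce_smoke_test()
    print(f"devices={n} result={summed} expected={expected}")
    assert jnp.allclose(summed, expected)
```

If RCCL is missing or misconfigured, a failure would typically surface at the `pmap` call on a multi-GPU machine rather than on a single device.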
We're working on upstreaming, but sometimes the upstream changes are too frequent, and the process of merging PRs does not always go as smoothly as expected.
I recommend using our released jax; it will be easier for you. https://github.com/ROCmSoftwarePlatform/jax/releases
Good to know. On a side note, I can't use the latest release here because I am specifically interested in a recent feature (jax.experimental.shard_alike).
I get the following error in jax, even though RCCL is linked.