google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[ROCm] JAX-ROCm docker images #7598

Open reza-amd opened 2 years ago

reza-amd commented 2 years ago

Hi,

As part of our effort to support JAX on the ROCm framework, we have published a preview release in the following DockerHub repository: https://hub.docker.com/repository/docker/rocm/jax

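We would appreciate your help with the following items:

- Announcing this release to the JAX community to try it and give us feedback to improve our support
- Helping us to set up CI builds similar to your internal infrastructure
- Helping us to release JAX on ROCm as Python wheels
- Guiding us to pick some representative benchmarks for performance tuning and detecting missing features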

@hawkinsp

mattjj commented 2 years ago

Thanks so much for this amazing work, and for bringing it to our attention!

> Announcing this release to the JAX community to try it and give us feedback to improve our support

I'll announce this to Google-internal users. We don't have a clear communication line to external folks... got any suggestions for what we should do on this front? We could mention it in the README.

> Helping us to setup CI-builds similar to your internal infrastructure
>
> Helping us to release JAX on ROCm as Python Wheels

@yashk2810 could you weigh in on this? (Note there's also #7323 as pointed out in the OP.)

(Assigning to Yash for now, to follow up on this point.)

> Guiding us to pick some representative benchmarks for performance tuning and detecting missing features

I'll ping some folks about this.

yashk2810 commented 2 years ago

This issue is being tracked in [ROCm] running unit test in parallel #7323

I replied on this issue about OSSing BUILD files for testing.

> Helping us to setup CI-builds similar to your internal infrastructure

Once the BUILD files are open-sourced, I can look into running Bazel tests using internal infra. I can't guarantee a timeline, but open-sourcing the BUILD files for testing should at least give you a way to test using Bazel.

How does that sound?

yashk2810 commented 2 years ago

> Helping us to release JAX on ROCm as Python Wheels

This is interesting. Maybe I can hook something up for this when I work on the release process for JAX. But if you already have a way to do this, feel free to try it out!

reza-amd commented 2 years ago

@mattjj, thanks so much for your attention to this matter.

> I'll announce this to Google-internal users. We don't have a clear communication line to external folks... got any suggestions for what we should do on this front? We could mention it in the README.

Mentioning this in the README would be great.

@yashk2810

> Once BUILD files are open-sourced, I can look into running bazel tests using internal infra. I can't guarantee a timeline but when BUILD files are open-sourced for testing that should at least give you a way to test using bazel.
>
> How does that sound?

Thanks so much for your help. It sounds good. Meanwhile we still use the approach mentioned in the documentation.

brettkoonce commented 2 years ago

See also: https://github.com/google/jax/tree/main/build/rocm

brettkoonce commented 2 years ago

See also: #2012

brettkoonce commented 8 months ago

@reza-amd is it possible to get a build with ROCm 5.7.1 support? I am interested in testing 7900 XTX compatibility!

brettkoonce commented 7 months ago

@rahulbatra85 any luck with your updates / is the plan to wait on this till 6.0? I got the pytorch image working locally w/ 5.7.1 and was able to train simple models!

rahulbatra85 commented 7 months ago

@brettkoonce We have been releasing wheels and docker images for ROCm for a while now. Please see this https://github.com/ROCmSoftwarePlatform/jax/releases

hawkinsp commented 7 months ago

@rahulbatra85 Do you want to send a PR improving https://jax.readthedocs.io/en/latest/installation.html ? I didn't quite know what to put there, and I think we could do a better job pointing to your releases.

rahulbatra85 commented 7 months ago

@hawkinsp yes, will update it.

brettkoonce commented 7 months ago

@rahulbatra85 Thank you for the link to the images. I have been trying to use them for the past month or two without success, so I assumed that my card (7900 XTX) was still not officially supported (e.g. that 5.7.1 was required). I filed a bug (#18747) with notes on what I am seeing on my machine and would appreciate any advice!

rahulbatra85 commented 7 months ago

@brettkoonce Sorry, I misunderstood your question. JAX does not currently support the 7900 XTX, but we plan to support it with ROCm 6.xxx. Our current best estimate is sometime next year.

I will keep you posted when I have an update!

Thanks!

brettkoonce commented 6 months ago

@rahulbatra85 thanks for the update! looking forward to it!

brettkoonce commented 2 months ago

@rahulbatra85 Thank you for the updated docker images! I am now able to train networks using JAX and ROCm 6.1.
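
For anyone wanting to repeat this kind of check inside one of the images, a minimal smoke test might look like the sketch below. It is not taken from this thread: the function names (`describe_backend`, `sgd_step`, `smoke_test`) and the toy least-squares problem are assumptions chosen purely for illustration. It prints the backend and devices that JAX reports, then runs a few jitted gradient-descent steps. On a working GPU install one would expect the reported backend to be something other than `cpu`, though the exact string may depend on the JAX version.

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Return the active backend name and the devices JAX can see."""
    return jax.default_backend(), jax.devices()


@jax.jit
def sgd_step(w, x, y, lr):
    """One gradient-descent step on a mean-squared-error loss."""
    def loss_fn(w_):
        return jnp.mean((x @ w_ - y) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(w)
    return w - lr * grad, loss


def smoke_test(steps=200):
    """Fit a toy linear model; return the first and last loss values."""
    key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (128, 4))   # synthetic inputs
    true_w = jax.random.normal(key_w, (4,))  # weights to recover
    y = x @ true_w
    w = jnp.zeros(4)
    losses = []
    for _ in range(steps):
        w, loss = sgd_step(w, x, y, 0.1)
        losses.append(float(loss))
    return losses[0], losses[-1]


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("backend:", backend)
    print("devices:", devices)
    first, last = smoke_test()
    print(f"loss: {first:.4f} -> {last:.6f}")
```

If the loss falls and the device list shows the GPU, the basic compile-and-run path is working. A failure here is likely easier to report in a bug than a failure deep inside a full training script.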