google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.81k stars 2.72k forks

AMD+Jax Docker container incompatible with Singularity #19138

Open PhilipVinc opened 8 months ago

PhilipVinc commented 8 months ago

Description

(Sorry if this is not the right place for this issue.)

I've recently been given access to France's ADASTRA machine, an AMD-based HPC supercomputer. It is not connected to the internet, so to get a Python installation running I must build a Singularity container somewhere else, upload it there, and find the right incantation to make it work.

I created the following Singularity definition and built it with the command apptainer build container.sif container_definition.

$ cat container_definition
Bootstrap: docker
From: rocm/jax:latest-py3110

%post
    pip install jaxopt diffrax
    pip install matplotlib

However, the resulting image is broken because of the following error, which occurs both during container generation and whenever I try to run the container:

source: /.singularity.d/env/10-docker2singularity.sh:2:8: invalid var name

Looking at the contents of the auto-generated 10-docker2singularity.sh, I see the following first lines:

$ head -n 3 /.singularity.d/env/10-docker2singularity.sh
#!/bin/bash
export XLA_CLONE_DIR:/workspace/jax-xla-repo="${XLA_CLONE_DIR:/workspace/jax-xla-repo:-}"
export XLA_BRANCH="${XLA_BRANCH:-"rocm-jaxlib-v0.4.21"}"

I believe the problem stems from the XLA_CLONE_DIR env variable being ill-defined: the colon in its name makes it an invalid shell identifier, so the export fails to parse.

This env variable, by the way, is not needed to run jax, so a simple solution would be to remove it.
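For reference, a colon is not a valid character in a shell variable name, which is why the generated line fails to parse. A minimal sketch of what the export was presumably meant to say (the default value is inferred from the broken line):

```shell
# Broken generated line (':' cannot appear in a variable name):
#   export XLA_CLONE_DIR:/workspace/jax-xla-repo="${XLA_CLONE_DIR:/workspace/jax-xla-repo:-}"
# Presumed intent: keep any existing value, otherwise default to the clone dir
unset XLA_CLONE_DIR
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-/workspace/jax-xla-repo}"
echo "$XLA_CLONE_DIR"   # prints /workspace/jax-xla-repo
```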

Do note that this issue was already reported in https://github.com/google/jax/issues/16997#issuecomment-1683519122, but no full solution was found and no action on the Docker images was taken.

Is it possible to fix the Docker container so that the auto-generated Singularity container works out of the box? This is causing me a lot of headaches: in academic HPC settings, Docker is rarely used because of licensing issues, and Singularity/Apptainer is the norm.

Thank you

What jax/jaxlib version are you using?

jax 0.4.21

Which accelerator(s) are you using?

AMD GPU

Additional system info?

No response

NVIDIA GPU info

No response

PhilipVinc commented 7 months ago

Hi @rahulbatra85, sorry for the ping, but this would simplify my work a lot. I saw you recently updated the Docker image for ROCm, so maybe you could take a look?

rahulbatra85 commented 7 months ago

@PhilipVinc You are right, these are not needed at runtime, only for building JAX.

I do see these in the Docker container (rocm/jax:rocm5.7.0-jax0.4.21-py3.11.0), though. Somehow they are getting modified in the Singularity container.

env | grep XLA
XLA_REPO=https://github.com/ROCmSoftwarePlatform/xla.git
XLA_BRANCH=rocm-jaxlib-v0.4.21
XLA_CLONE_DIR=/workspace/jax-xla-repo

Also, there is a 0.4.23 release now, and in that Docker image I don't see any XLA env vars set at all: docker pull rocm/jax:latest-py3110

/workspace# env | grep XLA
/workspace#

PhilipVinc commented 7 months ago

Thanks! So hopefully this is fixed now. Let me try and I'll report back

rlrs commented 5 months ago

This is still an issue using the latest images with Apptainer.

PhilipVinc commented 4 months ago

@rahulbatra85 this is still an issue.

PhilipVinc commented 3 weeks ago

@rahulbatra85, @Ruturaj4 sorry for the re-bump, but this is still an issue. Do you have a timeline for addressing this and releasing some manylinux-compliant wheels?

Ruturaj4 commented 3 weeks ago

Hi @PhilipVinc, yes, we are publishing manylinux-compliant wheels now! Could you please check one of our latest JAX containers from the rocm/jax Docker repository?

PhilipVinc commented 3 weeks ago

@Ruturaj4 thanks, I will try those Docker images now. By the way, where can I find the manylinux-compliant wheels? I don't see them in https://github.com/ROCm/jax/releases or on PyPI...

Ruturaj4 commented 3 weeks ago

@PhilipVinc Yes, they are not there yet; we are working on the latest release (0.4.31). They will be available once our tests are done. However, you can find them installed in the container.

PhilipVinc commented 2 weeks ago

@Ruturaj4 unfortunately, the only 'recent' container available at https://hub.docker.com/r/rocm/jax/tags (rocm6.2-jax0.4.23-py3.10) 1) ships an outdated version of JAX and 2) uses ROCm 6.2, which we do not yet have access to on our HPC cluster (we are stuck with 6.0).

Ruturaj4 commented 1 week ago

@mrodden do you have any comments on this?

@PhilipVinc can you try docker pull rocm/jax:rocm6.0.0-jax0.4.26-py3.11.0? The 0.4.26 version is still old, but this will give you something to get started with.

mrodden commented 1 week ago

@PhilipVinc After reading through the comments here, I am a bit confused about whether you need a Docker container image or the manylinux wheels. My guess is that you want the manylinux wheels to create your own Docker image to use on the system, but maybe a prebuilt image with JAX+ROCm might work for your use case.

AMD ROCm releases include a Docker container that has its own ROCm stack + JAX (and now the JAX ROCm plugin pieces) pre-installed. These are built, tested by our QA folks, and then shipped, so they are easier for AMD to support because we can test the full ROCm+JAX stack. They lag a bit behind because of the code freeze for testing.

The manylinux wheel support is very new, and we haven't gone through a full test cycle with the ROCm folks on the wheel path yet. I was considering publishing some "community release" wheels which don't go through full testing, but I was holding off a bit to make sure they would at least install and work for the ROCm builder guys. There is also the matter that these wheels include jax_rocm60_plugin and jax_rocm60_pjrt, which use the JAX PJRT plugin path, which is also pretty new.

I would like to understand your requirement that you can only use ROCm 6.0, because the docker images contain their own set of ROCM math libraries and also HIP runtime, so they should be independent of any other system deps. The only real thing we can't control outside the images is the Linux kernel driver, but that generally provides a fixed interface that doesn't change much.

PhilipVinc commented 1 week ago

Hi, thanks for getting back to me.

My objective is to run Jax on Ad Astra, a HPC with MI250 and MI300X GPUs.

https://dci.dci-gitlab.cines.fr/webextranet/index.html

I am aware of no publicly funded HPC supercomputer that allows users to use Docker images, because of security concerns. The few that allow containers at all require Singularity/Apptainer images.

Apptainer images can sometimes be obtained by a straightforward conversion process; however, for your images that is not the case, because some environment variables break the conversion, so the images have to be rebuilt from scratch, a lengthy process that is complex and easily goes wrong.
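(For concreteness, a conversion-time patch along these lines might sidestep the broken line. This is an untested sketch reusing the definition from the top of the thread; the sed rewrite targets the malformed export shown earlier and the default value is inferred from it.)

```
Bootstrap: docker
From: rocm/jax:latest-py3110

%post
    # Workaround (sketch): rewrite the malformed export generated during
    # docker-to-singularity conversion before it breaks later shells
    sed -i 's|^export XLA_CLONE_DIR:.*|export XLA_CLONE_DIR="/workspace/jax-xla-repo"|' \
        /.singularity.d/env/10-docker2singularity.sh || true
    pip install jaxopt diffrax
    pip install matplotlib
```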

The best solution would be to have manylinux-compliant wheels properly packaged.

The wheels you published so far require a too recent version of GLIBC, which is not available on Ad Astra.
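(As a side note, a quick way to check whether a host can use a manylinux-tagged wheel is to compare its glibc against the wheel's tag; manylinux_2_28 wheels, for example, need glibc >= 2.28. A sketch:)

```shell
# Print the host glibc version to compare against a wheel's manylinux tag
getconf GNU_LIBC_VERSION        # e.g. "glibc 2.28"
ldd --version | head -n 1       # same information, different formatting
```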

I would be very happy to beta-test your wheels and send you some feedback. An engineer at the HPC centre is working with me to compile jax/jaxlib on Ad Astra, but he's hitting several issues.

PhilipVinc commented 1 week ago

As for my requirement of ROCm 6.0: it's because the version of ROCm available on Ad Astra is 6.0.

Using a container would allow me to sidestep this limitation.

But your rocm/jax containers are broken when converted to Singularity, as per the original topic of this thread.

mrodden commented 1 week ago

@PhilipVinc Ok I understand now, thanks for the explanation.

I'm working on a 0.4.31 release at the moment, but what I might do is put up the wheels as a pre-release or something on the ROCm/jax repo so you could grab and test them.

It's also possible to build your own wheels with ROCm support now from the JAX repo. There is a script at build/rocm/ci_build that will build JAX+jaxlib+rocm_plugin in a Docker container. It requires Docker on the system, but it will pull the manylinux container, run a jax/jaxlib build, fix the wheels so they are manylinux_2_28, and then place them into ./wheelhouse for you to copy/install/upload/whatever. It should be something like:

cd jax
python3 build/rocm/ci_build --rocm-version 6.1.2 --python-versions 3.10,3.12 --compiler gcc dist_wheels

Also, I might have a new Ubuntu 22.04-based image that might work better, which I can push up for you to try. It does not depend on the hacks that the older Ubuntu 20.04-based one had.

PhilipVinc commented 4 days ago

Thanks! If you do put up the pre-release wheels somewhere, do let me know. If you want to keep them private, feel free to email them to me (my email is shown on my GitHub profile).