huggingface / Google-Cloud-Containers

Hugging Face Deep Learning Containers (DLCs) for Google Cloud
https://hf.co/docs/google-cloud
Apache License 2.0

Dockerfile for JAX GPU support #10

Closed shub-kris closed 8 months ago

shub-kris commented 9 months ago

This PR adds JAX GPU support for the DLCs.

It is very similar to the GPU Dockerfile; the only addition is installing JAX and Flax.
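
For illustration, the added install step could look roughly like the sketch below (assuming the same CUDA 12 base image as the GPU container; the exact extras and pins in this PR may differ):

# Sketch only: add JAX with CUDA support plus Flax on top of the existing GPU base image
RUN pip install --no-cache-dir --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    && pip install --no-cache-dir flax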

shub-kris commented 9 months ago

I also tried building from these two base images:

# FROM nvcr.io/nvidia/jax:23.10-py3  # Dropped because updating flax and jax on top of it was causing problems
# FROM nvcr.io/nvidia/cuda:12.2.2-cudnn8-devel-ubuntu20.04  # Here we needed to install Python ourselves, and transformers was pulling in PyTorch, so debugging could be a problem if the PyTorch version is not pinned

So I decided to use the same base image that we used for the GPU container.

shub-kris commented 9 months ago

Since the Flax weights aren't available yet, I couldn't test loading the weights for Gemma. I was basically trying to run this: https://github.com/huggingface/new-model-addition-golden-gate/blob/2126d1cae6e6b26456d1f0322d2db94f7eeac426/tests/models/golden_gate/test_modeling_flax_golden_gate.py#L236

The imports were working fine, though. I will ask @ArthurZucker about the timeline for the Flax weights.
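
For reference, a minimal import and device smoke test along these lines (not the exact command used here) confirms that Flax imports cleanly and that JAX sees the GPU:

# Should print the Flax version and a list containing a CUDA device
python -c "import jax, flax; print(flax.__version__); print(jax.devices())"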

I tested that Flax works in general by running this:

git clone https://github.com/huggingface/transformers
cd transformers
git checkout tags/v4.37.2
cd examples/flax/text-classification
export TASK_NAME=mrpc

python run_flax_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name ${TASK_NAME} \
  --max_seq_length 128 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 4 \
  --eval_steps 100 \
  --output_dir ./$TASK_NAME/ 
shub-kris commented 9 months ago

transformers then installs torch.

philschmid commented 9 months ago

transformers was pulling in PyTorch, so debugging could be a problem if the PyTorch version is not pinned

This should not be the case; please check with the transformers team. The JAX container should not have PyTorch installed.
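
One way to verify this in the built image (a suggested check, not part of this PR):

# Prints whether torch is importable inside the container; for the JAX image it should not be
python -c "import importlib.util as u; print('torch installed' if u.find_spec('torch') else 'torch not installed')"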

philschmid commented 9 months ago

Did you keep the version pins after removing pytorch?

ARG DIFFUSERS='0.26.1'
ARG PEFT='0.8.2'
ARG TRL='0.7.10'
ARG BITSANDBYTES='0.42.0'
ARG ACCELERATE='0.27.0'
ARG SENTENCE_TRANSFORMERS='2.3.1'
ARG DEEPSPEED='0.13.1'
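
(For context, these ARGs are typically consumed in a pinned install step along the lines of the sketch below; illustrative only, the actual Dockerfile may group or order these differently.)

# Illustrative only: consuming the pinned versions above in a single install step
RUN pip install --no-cache-dir \
    diffusers==${DIFFUSERS} \
    peft==${PEFT} \
    trl==${TRL} \
    bitsandbytes==${BITSANDBYTES} \
    accelerate==${ACCELERATE} \
    sentence-transformers==${SENTENCE_TRANSFORMERS} \
    deepspeed==${DEEPSPEED}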
shub-kris commented 9 months ago

I changed the order of installation as recommended in the transformers chat.

shub-kris commented 9 months ago

@philschmid good to merge?