stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

Add docs for installing TE from source #518

Open · dlwh opened this issue 9 months ago

dlwh commented 9 months ago

(We also need to make sure Levanter doesn't break when TE, NVIDIA's Transformer Engine, isn't installed.)
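One common way to avoid breaking without TE is to guard the import and fall back to a pure-JAX code path. The following is a minimal sketch, not Levanter's actual code: HAS_TE, select_attention_backend, and the backend names "te_fused" and "jax_reference" are all hypothetical.

```python
# Guard the optional Transformer Engine dependency so the rest of the code
# still imports when TE (or its JAX bindings) is missing.
try:
    import transformer_engine.jax  # noqa: F401  (only checking availability)
    HAS_TE = True
except ImportError:
    HAS_TE = False


def select_attention_backend(prefer_te: bool = True) -> str:
    """Return a (hypothetical) backend name, falling back when TE is absent."""
    if prefer_te and HAS_TE:
        return "te_fused"
    return "jax_reference"


if __name__ == "__main__":
    print("TE available:", HAS_TE)
    print("backend:", select_attention_backend())
```

The sketch only catches ImportError. A TE install that is present but linked against a broken CUDA setup might fail with a different exception type, in which case the except clause would need widening.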

WIP

First, run through the CUDA installer steps for Ubuntu 22.04 here:

https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local

Then:

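# use the ubuntu2204 repo so it matches the Ubuntu 22.04 installer linked above
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
# refresh the package index so apt can see the newly added NVIDIA repo
sudo apt-get update
sudo apt-get -y install cudnn9-cuda-12
# CUDA_HOME is not set by the installer; /usr/local/cuda is the usual symlink to the toolkit
export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
export PATH=$CUDA_HOME/bin/:$PATH
NVTE_FRAMEWORK=jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
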
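Before running the pip step, it can help to confirm that the CUDA pieces the commands above rely on are actually visible. The following is a minimal sketch; check_te_build_env is a hypothetical helper, and it only checks CUDA_HOME and nvcc, not everything a TE source build needs.

```python
import os
import shutil
from typing import Dict, Optional


def check_te_build_env(env: Optional[Dict[str, str]] = None) -> Dict[str, bool]:
    """Report whether CUDA_HOME and nvcc are visible in the given environment."""
    env = dict(os.environ) if env is None else env
    cuda_home = env.get("CUDA_HOME", "")
    nvcc_in_cuda_home = bool(cuda_home) and os.path.isfile(
        os.path.join(cuda_home, "bin", "nvcc")
    )
    # shutil.which accepts an explicit search path, so we check the given env's PATH.
    nvcc_on_path = shutil.which("nvcc", path=env.get("PATH", "")) is not None
    return {
        "CUDA_HOME set": bool(cuda_home),
        "nvcc under CUDA_HOME": nvcc_in_cuda_home,
        "nvcc on PATH": nvcc_on_path,
    }


if __name__ == "__main__":
    for check, ok in check_te_build_env().items():
        print(f"{check:<22} {'ok' if ok else 'MISSING'}")
```

Passing an explicit env dict makes the check easy to run against a hypothetical environment without touching the real one.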
Helw150 commented 8 months ago

I think the current implementation might depend on the nightly (main-branch) TE rather than stable. After installing the stable branch with fused attention, I'm still getting an import error: cannot import fused_attn from fused_attn.

It seems to be defined on the main branch but not on the stable branch: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/jax/fused_attn.py
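To tell which flavor of TE an environment ended up with, you can probe whether a module exposes a given name without letting the import error escape. The following is a minimal sketch; has_symbol is a hypothetical helper, and the module and symbol names in the usage line are taken from the error message above, not verified against either TE branch.

```python
import importlib


def has_symbol(module_name: str, symbol: str) -> bool:
    """True if module_name imports cleanly and defines symbol."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers a missing package as well as a missing submodule.
        return False
    return hasattr(module, symbol)


if __name__ == "__main__":
    # Names mirror the reported error; adjust to whatever your code imports.
    print(has_symbol("transformer_engine.jax.fused_attn", "fused_attn"))
```

Running the same probe in a stable-branch environment and a main-branch environment should show whether the name the code expects is actually present.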

dlwh commented 8 months ago

Ah, probably. We're mostly using the Docker container (which still needs docs too, cc @vadam5).

In the meantime, something like this ought to work:

docker run -v levanter:/levanter \
  --gpus="$SLURM_STEP_GPUS" \
  -e CUDA_VISIBLE_DEVICES -e WANDB_MODE=offline \
  --shm-size=16g -it \
  ghcr.io/stanford-crfm/levanter/levanter_jax:radium_test

sbhavani commented 3 months ago

We will be releasing a pip wheel for TE in the coming months, which should resolve these installation issues.