dlwh opened this issue 9 months ago
I think the current implementation might be using nightly rather than stable? After installing the stable release of fused attention, I'm still getting `cannot import fused_attn from fused_attn` errors.
It seems to be defined on the main branch, but not on the stable branch here: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/jax/fused_attn.py
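For anyone hitting the same thing, a quick way to tell which TE build you have is to probe for the module before importing it. This is only a minimal sketch using the standard library; `transformer_engine.jax.fused_attn` is the module path from the link above.

```python
# Minimal sketch: check whether the installed TransformerEngine build ships
# transformer_engine.jax.fused_attn (defined on main, missing on stable).
import importlib.util

try:
    spec = importlib.util.find_spec("transformer_engine.jax.fused_attn")
except ModuleNotFoundError:
    spec = None  # transformer_engine (or its jax subpackage) isn't importable at all

if spec is None:
    print("fused_attn not available; this TE install probably tracks the stable branch")
else:
    print("fused_attn found at", spec.origin)
```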
Ah, probably. We're mostly using the Docker container (which we still need docs for as well, cc @vadam5).
In the meantime, something like this ought to work:
docker run -v levanter:/levanter --gpus="$SLURM_STEP_GPUS" -e CUDA_VISIBLE_DEVICES -e WANDB_MODE=offline --shm-size=16g -it ghcr.io/stanford-crfm/levanter/levanter_jax:radium_test
We will be releasing a pip wheel for TE in the coming months that should resolve the installation issues.
(We also need to not break when TE isn't installed...)
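One minimal way to do that (a sketch only, not Levanter's actual code) is to guard the TE import and gate the fused path behind a flag, so callers can fall back to the plain-JAX attention path when the import fails:

```python
# Sketch (assumed pattern, not Levanter's actual implementation): guard the
# TransformerEngine import so the rest of the code still works without it.
try:
    from transformer_engine.jax import fused_attn  # absent on stable / when TE isn't installed
    HAS_TE_FUSED_ATTN = True
except ImportError:
    fused_attn = None
    HAS_TE_FUSED_ATTN = False

def use_fused_attention() -> bool:
    """Return True only when the TE fused attention kernels can be used."""
    return HAS_TE_FUSED_ATTN
```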
WIP
Run through this:
https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=deb_local
Then