aniquetahir / JORA

JORA: JAX Tensor-Parallel LoRA Library
https://anique.org/JORA/

8-bit training #1

Open yctam opened 3 months ago

yctam commented 3 months ago

Does the codebase support 8-bit training, similar to the peft library?

I was trying to fine-tune llama2-7b on 24GB 4090 cards. Below is the error I got:

File "/home/nlp/JORA/examples/train.py", line 14, in <module>
    main()
File "/home/nlp/JORA/examples/train.py", line 10, in main
    train_lora(config, dataset, 'checkpoints')
File "/home/nlp/JORA/jora/common.py", line 246, in train_lora
    lora_params, opt_state, total_loss, loss, key = train_step_lora(lora_params, loraConfig, params, opt_state, total_loss, data_batch, key)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 16320586680 bytes.
BufferAssignment OOM Debugging. BufferAssignment stats:
    parameter allocation: 3.58GiB
    constant allocation: 1.95MiB
    maybe_live_out allocation: 264.00MiB
    preallocated temp allocation: 15.20GiB
    total allocation: 19.04GiB

aniquetahir commented 3 months ago

For now it uses bfloat16, mainly because there is no bitsandbytes equivalent for JAX yet. However, there is some potential for adding 8-bit support through TransformerEngine.
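For reference, a minimal sketch of what the bfloat16 path amounts to in plain JAX (this is not JORA's actual loading code; the `params` pytree and the `to_bf16` helper are hypothetical): casting the floating-point leaves of a parameter pytree to bfloat16 roughly halves parameter memory compared to float32, but it is not the 4x-plus reduction that 8-bit quantization would give.

```python
import jax
import jax.numpy as jnp

def to_bf16(params):
    """Cast every floating-point leaf of a parameter pytree to bfloat16,
    leaving integer leaves (e.g. step counters) untouched."""
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.bfloat16)
        if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )

# Toy example; real usage would apply this to the loaded Llama weights.
params = {
    "w": jnp.ones((4, 4), dtype=jnp.float32),
    "step": jnp.array(0, dtype=jnp.int32),
}
params_bf16 = to_bf16(params)
print(jax.tree_util.tree_map(lambda x: x.dtype, params_bf16))
```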