AnswerDotAI / fsdp_qlora

Training LLMs with QLoRA + FSDP
Apache License 2.0

Running into CUDA out of memory with hqq_lora #25

Closed: zabirauf closed this issue 6 months ago

zabirauf commented 6 months ago

Setup:

I have 1x 3090 and 1x 4090, and I'm trying to follow the instructions in README.md to fine-tune with HQQ, but I'm running into a CUDA out-of-memory error:

python train.py --model_name meta-llama/Llama-2-70b-hf --batch_size 2 --context_length 2048 --precision bf16 --train_type hqq_lora --use_gradient_checkpointing true --use_cpu_offload true --dataset alpaca --log_to wandb

Error

Traceback (most recent call last):
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/train.py", line 939, in <module>
    def main(
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/script.py", line 125, in call_parse
    return _f()
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/script.py", line 119, in _f
    return tfunc(**merge(args, args_from_prog(func, xtra)))
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/train.py", line 1010, in main
    mp.spawn(fsdp_main,
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 158, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/train.py", line 625, in fsdp_main
    parallel(load_and_quantize_parallel, weights.items(), n_workers=n_workers, threadpool=True,
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/parallel.py", line 117, in parallel
    return L(r)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/foundation.py", line 98, in __call__
    return super().__call__(x, *args, **kwargs)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/foundation.py", line 106, in __init__
    items = listify(items, *rest, use_list=use_list, match=match)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/basics.py", line 66, in listify
    elif is_iter(o): res = list(o)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/fastcore/parallel.py", line 46, in _call
    return g(item)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/train.py", line 609, in load_and_quantize_parallel
    load_and_quantize(model, name, param, **kwargs)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/train.py", line 212, in load_and_quantize
    submodule.initialize()
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/hqq/core/quantize.py", line 280, in initialize
    self.quantize(self.linear_layer.weight.data, **self.quant_config)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/hqq/core/quantize.py", line 382, in quantize
    W_q , meta = Quantizer.quantize(W, device=self.device, compute_dtype=self.compute_dtype, **weight_quant_params)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/hqq/core/quantize.py", line 71, in quantize
    if(optimize): scale, zero = Quantizer.optimize_weights(tensor=W, scale=scale, zero=zero, min_max=min_max, axis=axis, device=device)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/hqq/core/optimize.py", line 166, in optimize_weights_proximal_legacy
    W_e   = shrink_op(W_f - W_r, beta)
  File "/home/ml-curious/Documents/Projects/Opensource/fsdp_qlora/.venv/lib/python3.10/site-packages/hqq/core/optimize.py", line 160, in <lambda>
    shrink_op = lambda x, beta,p=lp_norm: torch.sign(x)*torch.nn.functional.relu(torch.abs(x) - (1./beta)*torch.pow(torch.abs(x), p-1))
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 448.00 MiB. GPU 1 has a total capacity of 23.69 GiB of which 109.94 MiB is free. Including non-PyTorch memory, this process has 23.49 GiB memory in use. Of the allocated memory 22.66 GiB is allocated by PyTorch, and 549.12 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
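
A quick sanity check on the failing 448 MiB allocation. Assuming the published Llama-2-70B shapes (hidden size 8192, MLP intermediate size 28672) and a 2-byte (fp16/bf16) working dtype inside HQQ's weight optimization, one MLP projection matrix is exactly 448 MiB. That suggests each temporary built by `shrink_op` (`torch.abs(x)`, `torch.pow(...)`, `relu(...)`, etc.) is as large as a whole weight matrix, and several of them can be alive at once per layer being quantized. The helper below is purely illustrative arithmetic, not code from train.py or HQQ:

```python
# Illustrative arithmetic only. Assumes Llama-2-70B's shapes
# (hidden_size=8192, intermediate_size=28672) and a 2-byte dtype (fp16/bf16).
HIDDEN_SIZE = 8192
INTERMEDIATE_SIZE = 28672
BYTES_PER_ELEMENT = 2


def tensor_mib(rows: int, cols: int, bytes_per_element: int = BYTES_PER_ELEMENT) -> float:
    """Size in MiB of a dense rows x cols tensor."""
    return rows * cols * bytes_per_element / 2**20


# gate/up/down projections are 8192 x 28672 (or its transpose)
print(f"MLP projection: {tensor_mib(HIDDEN_SIZE, INTERMEDIATE_SIZE):.2f} MiB")  # 448.00 MiB
# q_proj / o_proj are 8192 x 8192
print(f"Attention projection: {tensor_mib(HIDDEN_SIZE, HIDDEN_SIZE):.2f} MiB")  # 128.00 MiB
```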

divine-taco commented 6 months ago

I can reproduce this. BitsAndBytes works, but HQQ runs into memory issues even with a very short context length (e.g. 512).

KeremTurgutlu commented 6 months ago

The OOM seems to happen during the model loading stage: we load and quantize the pretrained weights in parallel. You can try manually setting n_workers to a lower number here and running again: https://github.com/AnswerDotAI/fsdp_qlora/blob/0b57d37e7579fc5663638bcf9ba373ab7d52396c/train.py#L622
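Why a lower n_workers helps: with a thread pool, at most n_workers weights are being quantized at the same moment, so the transient memory from quantization scales with the number of workers rather than with the number of layers. The sketch below uses the standard library's `ThreadPoolExecutor` (not fastcore's `parallel`, which train.py calls with `threadpool=True`) and a hypothetical `fake_quantize` that "holds" a fixed amount of scratch memory while it runs. It only tracks numbers; no GPU is involved:

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class ScratchTracker:
    """Thread-safe counter of simulated in-flight scratch memory (MiB)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.current = 0.0
        self.peak = 0.0

    def acquire(self, mib: float) -> None:
        with self._lock:
            self.current += mib
            self.peak = max(self.peak, self.current)

    def release(self, mib: float) -> None:
        with self._lock:
            self.current -= mib


def fake_quantize(tracker: ScratchTracker, scratch_mib: float) -> None:
    # Hypothetical stand-in for quantizing one weight: holds scratch briefly.
    tracker.acquire(scratch_mib)
    try:
        time.sleep(0.01)
    finally:
        tracker.release(scratch_mib)


def peak_scratch(n_workers: int, per_weight_scratch_mib: list[float]) -> float:
    """Run all fake quantizations with n_workers threads; return peak scratch."""
    tracker = ScratchTracker()
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        # list() waits for every task and re-raises any worker exception.
        list(pool.map(lambda s: fake_quantize(tracker, s), per_weight_scratch_mib))
    return tracker.peak


weights = [448.0] * 16
print(peak_scratch(1, weights))  # 448.0: one weight at a time
print(peak_scratch(8, weights))  # at most 8 * 448.0
```

The peak is bounded by n_workers times the largest per-weight scratch, which is why reducing n_workers trades loading speed for a lower memory ceiling.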

KeremTurgutlu commented 6 months ago

With https://github.com/AnswerDotAI/fsdp_qlora/commit/cf614264fe7b1cdbadaf35172934a69d8d31e7de you should be able to load with HQQ without OOM on a 24GB GPU. Loading time increased from 5 minutes to 10 minutes; I will investigate further whether this was caused by recent changes in the HQQ repo.
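If you tune n_workers by hand instead, a back-of-the-envelope model can help pick a value. This is a rough sketch under simplifying assumptions (every weight takes about the same time, each worker holds a fixed amount of scratch memory, and some headroom is reserved); it does not describe what the linked commit changes, and both helper names are hypothetical:

```python
import math


def max_workers_for_budget(
    free_mib: float, scratch_per_worker_mib: float, headroom_mib: float = 1024.0
) -> int:
    """Largest worker count whose combined scratch fits in free memory minus headroom (min 1)."""
    usable = free_mib - headroom_mib
    return max(1, int(usable // scratch_per_worker_mib))


def estimated_load_minutes(n_items: int, n_workers: int, minutes_per_item: float) -> float:
    """Wall-clock estimate if items are processed in waves of n_workers."""
    return math.ceil(n_items / n_workers) * minutes_per_item


# Assumed: 8 GiB free, ~5 same-sized 448 MiB temporaries alive per worker.
print(max_workers_for_budget(8192, 5 * 448))  # 3
# Halving the worker count roughly doubles the estimated load time.
print(estimated_load_minutes(80, 8, 0.5), estimated_load_minutes(80, 4, 0.5))  # 5.0 10.0
```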