Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

LoRA multi-GPU no longer works if applying LoRA selectively #1385

Closed awaelchli closed 1 week ago

awaelchli commented 1 week ago
litgpt finetune lora --config config_hub/finetune/tiny-llama/lora.yaml \
  --devices 2 --train.global_batch_size 8 --train.micro_batch_size 2 --train.max_steps 3 --lora_key False
Traceback (most recent call last):
  File "/home/adrian/.conda/envs/lit-gpt/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/home/adrian/repositories/lit-gpt/litgpt/__main__.py", line 143, in main
    fn(**kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/finetune/lora.py", line 143, in setup
    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 845, in launch
    return self._wrap_and_launch(function, self, *args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 931, in _wrap_and_launch
    return to_run(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 936, in _wrap_with_setup
    return to_run(*args, **kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/finetune/lora.py", line 199, in main
    fit(
  File "/home/adrian/repositories/lit-gpt/litgpt/finetune/lora.py", line 261, in fit
    validate(fabric, model, val_dataloader, dataclasses.replace(eval, max_iters=2))  # sanity check
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/finetune/lora.py", line 356, in validate
    logits = model(input_ids)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 139, in forward
    output = self._forward_module(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/lora.py", line 555, in forward
    x = block(x, cos, sin, mask, input_pos)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    ret = function(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/model.py", line 180, in forward
    attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/model.py", line 215, in forward
    qkv = self.attn(x)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lit-gpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/repositories/lit-gpt/litgpt/lora.py", line 440, in forward
    lora = self.zero_pad(after_B) * self.scaling  # (64, 64, 256) after zero_pad (64, 64, 384)
  File "/home/adrian/repositories/lit-gpt/litgpt/lora.py", line 345, in zero_pad
    self._lora_ind_cache[result.device] = lora_ind = self._lora_ind.to(result.device)
NotImplementedError: Cannot copy out of meta tensor; no data!

On main, this fails. On the commit before #1374 it works.

#1374 made lora_ind a tensor but doesn't re-initialize it in reset_parameters(), so under FSDP's meta-device initialization it is left on the meta device when the rest of the model gets materialized.
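For context, here is a minimal plain-PyTorch sketch of the failure mode (LoRALinearLike is a hypothetical stand-in, not litgpt code): a plain tensor attribute created in __init__ is neither a parameter nor a registered buffer, so meta-device initialization followed by to_empty() and reset_parameters() leaves it without real storage, and the first forward pass hits the same NotImplementedError as above.

```python
import torch
import torch.nn as nn


class LoRALinearLike(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Plain tensor attribute (not a Parameter or registered buffer):
        # created once on whatever device is active at construction time.
        self.lora_ind = torch.arange(out_features)

    def reset_parameters(self) -> None:
        self.linear.reset_parameters()
        # Bug analogue: lora_ind is NOT re-created here, so it keeps the
        # meta storage it was constructed with.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Copying a meta tensor to a real device has no data to copy and raises
        # "NotImplementedError: Cannot copy out of meta tensor; no data!"
        ind = self.lora_ind.to(x.device)
        return self.linear(x).index_select(-1, ind)


with torch.device("meta"):         # roughly how empty/meta init creates the module
    module = LoRALinearLike(4, 4)  # parameters AND lora_ind live on "meta"

module.to_empty(device="cpu")      # parameters/buffers get real (empty) storage
module.reset_parameters()          # weights are re-initialized, lora_ind is not
module(torch.randn(2, 4))          # raises NotImplementedError as in the traceback
```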

rasbt commented 1 week ago

I think this might also be related to #1378

carmocca commented 1 week ago

Should be fixed by #770

awaelchli commented 1 week ago

I ran the example and the fix works, thanks. I left a suggestion for an improved test that could catch further instances of similar issues.
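The suggested test itself isn't shown in this thread, but one possible shape for such a check (a sketch; find_meta_tensors is a hypothetical helper, not litgpt code) would be to materialize a meta-initialized model and assert that no tensor, including plain tensor attributes like lora_ind, is still on the meta device:

```python
import torch
import torch.nn as nn


def find_meta_tensors(module: nn.Module) -> list[str]:
    """Return names of all tensors still on the meta device, including plain
    tensor attributes that named_parameters()/named_buffers() would miss."""
    offenders = []
    for name, tensor in list(module.named_parameters()) + list(module.named_buffers()):
        if tensor.is_meta:
            offenders.append(name)
    for mod_name, mod in module.named_modules():
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor) and value.is_meta:
                offenders.append(f"{mod_name}.{attr}" if mod_name else attr)
    return offenders


# After building the model under meta-device init and materializing it
# (e.g. to_empty() + reset_parameters(), or loading a checkpoint), a test could do:
# assert not find_meta_tensors(model), find_meta_tensors(model)
```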