Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

OptimizerArgs #1409

Closed. rasbt closed this 1 week ago.

rasbt commented 3 weeks ago

This PR unbundles the OptimizerArgs approach from GaLore in #1192.

Todos

rasbt commented 3 weeks ago

@carmocca, your jsonargparse example has been super helpful for understanding things a bit better. Many thanks for this!

Maybe it's because it's Friday evening, but my brain is just not working today. I've been banging my head against how to get this into the finetuning method's setup.

Adding the optimizer subclass to the parser and then calling the finetune command yields the following error:

  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt/litgpt/__main__.py", line 145, in main
    fn(**kwargs)
TypeError: setup() got an unexpected keyword argument 'optimizer.class_path'
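The traceback shows main() in litgpt/__main__.py forwarding everything with fn(**kwargs). A flattened key such as 'optimizer.class_path' contains a dot, so no Python parameter name can ever match it. A minimal standard-library reproduction, where setup() is a hypothetical stand-in for the finetuning setup():

```python
def setup(checkpoint_dir: str = "checkpoints", precision: str = "bf16-true") -> str:
    """Hypothetical stand-in for a finetuning setup() without an optimizer parameter."""
    return checkpoint_dir


# Flattened, dotted keys like the ones the parser hands to main()
kwargs = {
    "checkpoint_dir": "checkpoints/EleutherAI/pythia-160m",
    "optimizer.class_path": "torch.optim.SGD",
}

message = ""
try:
    setup(**kwargs)
except TypeError as err:
    # A dotted key is not a valid identifier, so it can never bind to a parameter
    message = str(err)
print(message)
```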

But we can't add an optimizer argument to the finetuning setup signature, because it would then be registered twice. Conceptually, I am kind of stuck here.

Also, how would we get the args in

optimizer = instantiate_class(model.parameters(), init=args["optimizer"])

if we don't pass them on from the main() function in __main__.py? I suppose I could do the following

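import jsonargparse
import torch

parser = jsonargparse.ArgumentParser()
# Accept any torch.optim.Optimizer subclass under "optimizer" without instantiating it;
# "params" is skipped because model.parameters() is passed at instantiation time
parser.add_subclass_arguments(torch.optim.Optimizer, "optimizer", instantiate=False, fail_untyped=False, skip={"params"})
args = parser.parse_args()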

in the finetuning script, but then it would erase all the previous arguments.

carmocca commented 2 weeks ago

As far as integrating into the scripts, I would:

1. Create an optimizer argument here: https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/finetune/lora.py#L40

2. To avoid the duplicate registration, skip it when the function arguments are added: https://github.com/omni-us/jsonargparse/blob/2de15ddfb1c02c2f7b3fe913ad11f13c5cb65dff/jsonargparse/_signatures.py#L166 and https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/__main__.py#L121

3. Then call instantiate_class here: https://github.com/Lightning-AI/litgpt/blob/36c6a77435d75872f525848ee1570467d120ae80/litgpt/finetune/lora.py#L185-L187
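The skip mentioned above for avoiding the duplicate registration can be illustrated without jsonargparse. This is an analogy using only argparse and inspect, not how jsonargparse implements it: one flag is generated per parameter of a hypothetical setup(), the skipped name is left out, and --optimizer is then registered once by hand (in litgpt, that second step would be the subclass registration). add_function_flags is a hypothetical helper.

```python
import argparse
import inspect
from typing import Any, Callable, Dict, Optional, Set


def setup(checkpoint_dir: str = "checkpoints", optimizer: Optional[str] = None, devices: int = 1) -> Dict[str, Any]:
    """Hypothetical stand-in for litgpt's finetuning setup()."""
    return {"checkpoint_dir": checkpoint_dir, "optimizer": optimizer, "devices": devices}


def add_function_flags(parser: argparse.ArgumentParser, fn: Callable[..., Any], skip: Set[str]) -> None:
    """Register one --flag per parameter of fn, except the names listed in skip."""
    for name, param in inspect.signature(fn).parameters.items():
        if name in skip:
            continue
        # Use the annotation as the converter only when it is a plain builtin type
        arg_type = param.annotation if param.annotation in (int, str, float) else str
        parser.add_argument(f"--{name}", type=arg_type, default=param.default)


parser = argparse.ArgumentParser()
add_function_flags(parser, setup, skip={"optimizer"})
# The optimizer is registered exactly once, separately from the function parameters
parser.add_argument("--optimizer", type=str, default="AdamW")

args = parser.parse_args(["--checkpoint_dir", "checkpoints/EleutherAI/pythia-160m", "--optimizer", "SGD"])
result = setup(**vars(args))
print(result)

# Without the skip, registering --optimizer a second time is rejected
dup_parser = argparse.ArgumentParser()
add_function_flags(dup_parser, setup, skip=set())
try:
    dup_parser.add_argument("--optimizer", type=str)
    duplicate_rejected = False
except argparse.ArgumentError:
    duplicate_rejected = True
```

Dropping the skip makes the second registration fail with argparse.ArgumentError, which is the standard-library counterpart of the duplicate-argument problem.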

This should be enough to unblock you. The not-so-nice thing is that the CLI args structure leaks into the actual script, meaning that users who don't go through the CLI will have to create this dictionary manually.
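For users who call the script directly instead of going through the CLI, the dictionary in question has the {"class_path": ..., "init_args": {...}} layout. Below is a minimal sketch of turning such a dict into an object, roughly mirroring what Lightning's instantiate_class does: import the class from its dotted path, then call it with the positional arguments plus init_args. instantiate_from_dict is a hypothetical name, and the runnable usage swaps in datetime.timedelta and collections.Counter for torch.optim.SGD so it works without PyTorch.

```python
import importlib
from typing import Any, Dict, Tuple, Union


def instantiate_from_dict(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Build an object from a {'class_path': ..., 'init_args': {...}} dict (hypothetical helper)."""
    if not isinstance(args, tuple):
        args = (args,)  # a single positional argument, e.g. model.parameters()
    module_name, class_name = init["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(*args, **init.get("init_args", {}))


# What a non-CLI user would build by hand (with PyTorch: instantiate_from_dict(model.parameters(), optimizer_config))
optimizer_config = {"class_path": "torch.optim.SGD", "init_args": {"lr": 0.001, "momentum": 0.9}}

# Runnable stand-in using a standard-library class instead of an optimizer
delta = instantiate_from_dict((), {"class_path": "datetime.timedelta", "init_args": {"hours": 1, "minutes": 30}})
print(delta.total_seconds())
```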

rasbt commented 2 weeks ago

Awesome, thanks so much, this was a great help! I figured it out and got it to work. Many thanks, I learned something new again!

rasbt commented 2 weeks ago

I now have it working as follows:

# Use the default optimizer:
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m

# Specify optimizer and optimizer args:
litgpt finetune full \
  --checkpoint_dir checkpoints/EleutherAI/pythia-160m \
  --optimizer torch.optim.SGD \
  --optimizer.init_args.lr 1000

But the way I am passing the optimizer kwargs feels a bit hacky. Is there a built-in or better way to handle this, @carmocca? The issue is that when I pass an --optimizer argument, it also passes additional flattened kwargs to setup:

kwargs = {
    'optimizer.class_path': 'torch.optim.SGD',
    'optimizer.init_args.dampening': 0.0,
    'optimizer.init_args.differentiable': False,
    'optimizer.init_args.foreach': None,
    'optimizer.init_args.lr': 0.001,
    'optimizer.init_args.maximize': False,
    'optimizer.init_args.momentum': 0.0,
    'optimizer.init_args.nesterov': False,
    'optimizer.init_args.weight_decay': 0.0
}

That's why I added code that parses these into class_path and init_args:

    optimizer_class_path = None
    optimizer_init_args = {}
    for key, value in list(kwargs.items()):
        if key.startswith("optimizer"):
            if "class_path" in key:
                optimizer_class_path = value
            elif "init_args" in key:
                init_arg_key = key.split(".")[-1]
                optimizer_init_args[init_arg_key] = value
            del kwargs[key]

Everything seems to work, but I wonder whether there is a better way to do it.
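A somewhat more general variant of the loop above, as a sketch: fold every key that starts with "optimizer." into a nested dict with the same class_path/init_args layout, without deleting from kwargs while iterating and without catching unrelated keys that merely begin with "optimizer". group_prefixed_kwargs is a hypothetical helper and assumes the flattened key format printed above.

```python
from typing import Any, Dict, Tuple


def group_prefixed_kwargs(kwargs: Dict[str, Any], prefix: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into (remaining, nested), where nested holds the '<prefix>.' keys."""
    remaining: Dict[str, Any] = {}
    nested: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if not key.startswith(prefix + "."):
            remaining[key] = value
            continue
        # 'optimizer.init_args.lr' -> ['init_args', 'lr']
        parts = key[len(prefix) + 1:].split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return remaining, nested


kwargs = {
    "checkpoint_dir": "checkpoints/EleutherAI/pythia-160m",
    "optimizer.class_path": "torch.optim.SGD",
    "optimizer.init_args.lr": 0.001,
    "optimizer.init_args.momentum": 0.0,
}
remaining, optimizer = group_prefixed_kwargs(kwargs, "optimizer")
print(remaining)
print(optimizer)
```

The resulting optimizer dict can then be handed to instantiate_class unchanged, and kwargs itself is left untouched.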

carmocca commented 1 week ago

@rasbt I pushed a commit with what I would suggest. The str code path could be improved if we want to expose arguments like the learning rate outside of the CLI, but that should be straightforward to implement.
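The "str code path" refers to accepting the optimizer either as a string or as a full dict. One way to normalize both, and to optionally expose extra init arguments such as the learning rate for non-CLI callers, is sketched below. normalize_optimizer_arg, its **extra_init_args parameter, and the assumption that bare names live in torch.optim are all hypothetical, not what the pushed commit does.

```python
from typing import Any, Dict, Union


def normalize_optimizer_arg(optimizer: Union[str, Dict[str, Any]], **extra_init_args: Any) -> Dict[str, Any]:
    """Return a {'class_path', 'init_args'} dict from either a string or a dict."""
    if isinstance(optimizer, str):
        # Assumption: bare names such as "AdamW" refer to classes in torch.optim
        class_path = optimizer if "." in optimizer else f"torch.optim.{optimizer}"
        config: Dict[str, Any] = {"class_path": class_path, "init_args": {}}
    else:
        # Copy init_args so the caller's dict is not mutated by the update below
        config = {"class_path": optimizer["class_path"], "init_args": dict(optimizer.get("init_args", {}))}
    config["init_args"].update(extra_init_args)
    return config


print(normalize_optimizer_arg("AdamW", lr=2e-4))
print(normalize_optimizer_arg({"class_path": "torch.optim.SGD", "init_args": {"lr": 0.1}}))
```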

Also, FYI: you don't need to specify the .init_args substring on the command line; for example, --optimizer.lr works in place of --optimizer.init_args.lr.

rasbt commented 1 week ago

The only caveat now is that the class path still needs to be specified. That is, specifying only the learning rate doesn't work:

litgpt finetune full --optimizer.lr 200 --checkpoint_dir checkpoints/EleutherAI/pythia-160m

error: Parser key "optimizer":
  Not a valid subclass of Optimizer. Got value: NestedArg(key='lr', val='200')
  Subclass types expect one of:
  - a class path (str)
  - a dict with class_path entry
  - a dict without class_path but with init_args entry (class path given previously)

So the optimizer always needs to be specified explicitly, as in:

litgpt finetune full --optimizer AdamW --optimizer.lr 200 --checkpoint_dir checkpoints/EleutherAI/pythia-160m

Do you know if that's a jsonargparse thing, @carmocca? Since we already set a default value in the setup method, this seems a bit odd to me.
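The error text spells out jsonargparse's rule: init_args without a class_path are accepted only when a class path was given previously, and judging from the error, the default set in setup() was not treated as one here. If the script should fill that gap itself, a fallback could look like the sketch below; resolve_optimizer_config and the AdamW default are assumptions for illustration, not litgpt's behavior.

```python
from typing import Any, Dict, Optional


def resolve_optimizer_config(
    cli_value: Optional[Dict[str, Any]], default_class_path: str = "torch.optim.AdamW"
) -> Dict[str, Any]:
    """Fill in a default class_path when only init_args were provided (hypothetical)."""
    if not cli_value:
        return {"class_path": default_class_path, "init_args": {}}
    return {
        "class_path": cli_value.get("class_path", default_class_path),
        "init_args": dict(cli_value.get("init_args", {})),
    }


# Roughly what passing only `--optimizer.lr 200` would need to resolve to
print(resolve_optimizer_config({"init_args": {"lr": 200.0}}))
```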

rasbt commented 1 week ago

I hope this is ready now, @carmocca.

carmocca commented 1 week ago

The Azure failure does look real:

>       fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval, optimizer)
E       TypeError: fit() takes 9 positional arguments but 10 were given

/__w/6/s/extensions/thunder/pretrain.py:229: TypeError
----------------------------- Captured stderr call -----------------------------
Missing logger folder: /tmp/pytest-of-root/pytest-0/test_pretrain0/out/logs/tensorboard
Seed set to 42
=========================== short test summary info ============================
FAILED tests/test_thunder_pretrain.py::test_pretrain - TypeError: fit() takes 9 positional arguments but 10 were given

rasbt commented 1 week ago

It does. Let me investigate...

rasbt commented 1 week ago

Should be fixed for good now, @carmocca. I can switch the link to the original TinyStories now that you have seen the green checks, haha 😆