karpathy / llm.c

LLM training in simple, raw C/CUDA
MIT License

init from scratch #243

Open karpathy opened 5 months ago

karpathy commented 5 months ago

Follow the GPT-2 reference .py file and initialize the weights in C from scratch in the exact same way. Allow init from scratch instead of init from checkpoint when building the GPT-2. Add argparse flag to configure which way to go. Ok to only change the mainline development file train_gpt2.cu.

Neruelin commented 5 months ago

I hope I understood the issue correctly: the mainline train_gpt2.cu cannot train from fresh model weights, unlike train_gpt2.py. My solution reuses the GPT class from train_gpt2.py (which loads the model weights) and its write_model() function, which serializes the model. I created a new utility script, gen_base_weights_checkpoint.py, that imports this class and function and writes fresh model weight checkpoint files. The script defaults to a checkpoint for the 124M-parameter model, because that is the model hardcoded in train_gpt2.cu, but it can also write a fresh checkpoint for any of the available model types via the command-line argument --model_type.

Additionally, train_gpt2.cu supports a new CLI arg, -c (checkpoint), which sets the path of the checkpoint to load. Previously it always used the modified weights output by train_gpt2.py; now that we can create unmodified model weight files, we can train from scratch.
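As a rough sketch of how a checkpoint flag could double as the init switch requested in the issue: if -c is given, load that checkpoint, otherwise initialize from scratch. This is an assumption for illustration only, not the code in the PR; `InitOptions`, `parse_init_options`, and `should_init_from_scratch` are hypothetical names, and the sketch handles only the -c flag.

```c
#include <stddef.h>
#include <string.h>

// Hypothetical options struct for this sketch.
typedef struct {
    const char *checkpoint_path; // NULL means "initialize from scratch"
} InitOptions;

// Parse "-c <path>" pairs from argv.
// Returns 0 on success, -1 on a flag without a value or an unknown flag.
int parse_init_options(int argc, char **argv, InitOptions *opts) {
    opts->checkpoint_path = NULL;
    for (int i = 1; i < argc; i += 2) {
        if (i + 1 >= argc) {
            return -1; // flag is missing its value
        }
        if (strcmp(argv[i], "-c") == 0) {
            opts->checkpoint_path = argv[i + 1];
        } else {
            return -1; // this sketch knows no other flags
        }
    }
    return 0;
}

// With no checkpoint given, the caller would run its from-scratch init.
int should_init_from_scratch(const InitOptions *opts) {
    return opts->checkpoint_path == NULL;
}
```

The design choice here is that the absence of a checkpoint path selects random initialization, so no Python-generated file is needed for a fresh run.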

I was hoping this solution checks all the boxes here. It's fairly straightforward, and it should initialize in exactly the same way as the Python version (because it shares the same functionality). I also updated README.md to explain this from-scratch option and provide a quickstart-like bash snippet.

Please let me know if there's any room for improvement. EDIT: link to PR

karpathy commented 5 months ago

Sorry, to clarify: I want to remove the need for Python in this repo. It's a nice-to-have for correctness checks but shouldn't be required. Right now it outputs the weights we init from, so it's kind of required.

Neruelin commented 5 months ago

Thanks for the quick reply, I'll work on a CUDA-only solution. A comment on the PR mentioned that the weights should be random, i.e. with no dependency on the HF model. Just confirming: is that also the desired behavior?