rasbt / LLMs-from-scratch

Implement a ChatGPT-like LLM in PyTorch from scratch, step by step
https://www.amazon.com/Build-Large-Language-Model-Scratch/dp/1633437167

Best practices for memory efficient weight loading tutorial #402

Closed mikaylagawarecki closed 1 month ago

mikaylagawarecki commented 1 month ago

Bug description

Thanks for putting together this great tutorial and showing the pros and cons of each of the available options for model loading. I want to add a caveat regarding the section on mmap.

My recommendation on best practices for loading the model memory-efficiently, where I define memory efficiency as

would be the following:

def best_practices():
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

Note that this will print:

Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 6.0 GB

At which point, I expect you to say: hey! 6.0 GB is more than the CPU memory used by the load_sequentially_with_meta example.

Agreed! mmap is a syscall and hence we do not have fine-grained control over exactly how much CPU RAM will be used. However, the nice(!) thing about mmap=True is that you should be able to load your model regardless of the user's limitations on CPU RAM :)
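The on-demand behavior of the underlying syscall can be seen with Python's standard-library mmap module alone (a minimal sketch; the file contents and sizes are made up): mapping a file reserves address space but reads no data, and physical pages are only faulted in for the regions actually touched.

```python
import mmap
import os
import tempfile

# Create a 16 MiB scratch file standing in for a checkpoint.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"\x00" * (8 * 1024 * 1024))
    f.write(b"weights!")
    f.write(b"\x00" * (8 * 1024 * 1024))
    path = f.name

with open(path, "rb") as f:
    # The mapping itself reads no data; it only sets up page-table entries.
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        # Touching this slice faults in only the pages that back it --
        # the surrounding 16 MiB stays on disk unless accessed.
        chunk = mm[8 * 1024 * 1024 : 8 * 1024 * 1024 + 8]

os.unlink(path)
print(chunk)  # b'weights!'
```

This is exactly why the OS, not user code, decides how much of the mapping is resident at any moment: under memory pressure it can simply drop clean pages and re-fault them later.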

What actually happens when setting mmap=True + map_location=device is that the checkpoint file will be mmap-ed, and then slices of it (corresponding to each storage) will be sent to the device. While I don't know of a good way to demonstrate this in an ipynb (resource.setrlimit(rss) doesn't actually work), if you launch a Docker container with limited CPU RAM, I expect you should be able to see this.

So I would not recommend that a user save and load each parameter tensor separately, as is done in section 7.

Let me know what you think :)

rasbt commented 1 month ago

Thanks a lot for opening this issue, I really appreciate these insights! I think I now understand better why the mmap approach didn't look so great in practice: it's basically an "on-demand" mechanism, and since this machine has plenty of memory, mmap (smartly) doesn't need to do much here.

I just updated the notebook and flagged the "mmap" method as the recommended one in the section header. Thanks again!