load/save_state_dict hooks for early transforms #483

Closed by t-vi 1 month ago

t-vi commented 4 months ago

Currently, we do a number of things in the distributed early transforms.

Looking at the module lifetime, it would be awesome to have a mechanism to load/save state dicts targeted at the original model. Like the transformations themselves, this should be composable across several transforms. For efficiency reasons, it should be able to process one parameter (or the parameters of one submodule?) at a time.
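
A minimal sketch of what such a hook mechanism could look like (all names here are hypothetical and purely illustrative, not an existing Thunder API): each transform contributes a per-parameter hook, loading chains the hooks in transform order, and parameters are processed one at a time.

```python
from typing import Any

class StateDictHook:
    """Hypothetical per-transform hook; all names are illustrative."""

    def load_param(self, name: str, param: Any) -> Any:
        # Map one parameter of the original model to its transformed form
        # (e.g. shard it for FSDP or quantize it).
        return param

    def save_param(self, name: str, param: Any) -> Any:
        # Inverse mapping, so the saved state dict matches the original model.
        return param


def load_state_dict_through(hooks: list[StateDictHook], state_dict: dict) -> dict:
    # Process one parameter at a time so peak memory stays near one tensor.
    out = {}
    for name, param in state_dict.items():
        for hook in hooks:  # hooks compose in transform order
            param = hook.load_param(name, param)
        out[name] = param
    return out
```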

This would likely also help with the memory issue #478, because we could then leave things on the meta device until we need them (at the end of the transforms) if we need several steps to merge things.
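
For illustration, PyTorch's meta device already supports the deferred-materialization part of this (a generic PyTorch sketch, not Thunder-specific):

```python
import torch
import torch.nn as nn

# Build the module on the meta device: shapes and dtypes only, no storage.
with torch.device("meta"):
    m = nn.Linear(8, 8)

# ... transforms could run here while the parameters are still on meta ...

# Materialize storage only at the end, then fill it with the real weights.
m = m.to_empty(device="cpu")
m.load_state_dict({"weight": torch.randn(8, 8), "bias": torch.randn(8)})
```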

So the goal would be to be able to do:

```python
# m is the model, instantiated on the meta device
import thunder
from thunder.distributed import fsdp  # assuming these imports

cm = fsdp(quantize(thunder.jit(m)))  # quantize: some quantization transform
cm.load_state_dict(state_dict_of_m)
```

crcrpar commented 4 months ago

nit: isn't it `fsdp(quantize(thunder.jit(m)))`?

t-vi commented 4 months ago

Probably. :)

t-vi commented 4 months ago

Discussing with Masaki: we'll make the early transforms a class with methods to achieve the various transformations, which also gives a good record of the details of each transformation.
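
A sketch of that direction (method names are illustrative; the actual Thunder base class may differ): each early transform becomes a class whose methods cover both the module rewrite and the matching state-dict mapping, so the class itself records the details of the transformation.

```python
class EarlyTransform:
    """Illustrative base class, not Thunder's exact API."""

    def transform_module(self, module):
        # Rewrite parameters/buffers, e.g. shard or quantize them.
        return module

    def transform_state_dict_for_submodule(self, module, submodule_name, state_dict):
        # Map one submodule's original state dict to the transformed
        # layout, one submodule at a time for memory efficiency.
        return state_dict
```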

crcrpar commented 4 months ago

cc @kshitij12345

t-vi commented 1 month ago

The goal in the title has been achieved, and we now have `load_original_state_dict`. The composability of FSDP with the other transforms is still WIP; I will file a separate issue about it.
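
For reference, usage ends up looking roughly like this, reusing `m` and `state_dict_of_m` from the goal snippet above (a sketch; `SomeQuantizeTransform` is a placeholder, and the exact signatures should be checked against the current Thunder docs):

```python
import thunder

cm = thunder.jit(m, transforms=[SomeQuantizeTransform()])  # placeholder transform
cm.load_original_state_dict(state_dict_of_m)  # takes the *original* model's state dict
```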