t-vi closed 1 month ago
nit: isn't it fsdp(quantize(thunder.jit(m)))?
Probably. :)
Discussing with Masaki: We'll make the early transforms a class whose methods implement the various transformations and which also serves as a good repository for the details of each transformation.
cc @kshitij12345
The goal in the title has been achieved. And we have load_original_state_dict. The composability of FSDP with other transforms is still WIP and I will file a separate issue about it.
Currently, we do a number of things in the distributed early transforms:
Looking at the module lifetime, it would be awesome to have a mechanism to load/save state dicts targeted at the original model. Like the transformations themselves, this should be composable across several transforms. For efficiency reasons, it should be able to process one parameter (or those of one submodule?) at a time.
This would likely also help with the memory issue #478, because we could then leave things on the meta device until we need them (at the end of the transforms) if we need several steps to merge things.
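To make the composable, one-parameter-at-a-time idea concrete, here is a rough sketch. It is not Thunder's actual API: `RoundTransform`, `ShardTransform`, `transform_entry` and `stream_transformed` are hypothetical names, plain lists of floats stand in for tensors, and rounding is only a toy stand-in for quantization. The point is just that each transform maps a single original (name, value) entry, and that entries are streamed through the chain so only one parameter is in flight at a time.

```python
from typing import Iterable, Iterator, List, Sequence, Tuple

# Plain lists of floats stand in for tensors in this sketch.
Tensor = List[float]
Entry = Tuple[str, Tensor]


class RoundTransform:
    """Hypothetical toy stand-in for quantization: round every value."""

    def transform_entry(self, name: str, value: Tensor) -> Entry:
        return name, [float(round(v)) for v in value]


class ShardTransform:
    """Hypothetical toy stand-in for FSDP: keep only this rank's slice."""

    def __init__(self, rank: int, world_size: int) -> None:
        self.rank = rank
        self.world_size = world_size

    def transform_entry(self, name: str, value: Tensor) -> Entry:
        # Assumes the length is divisible by world_size (no padding here).
        n = len(value) // self.world_size
        return name, value[self.rank * n : (self.rank + 1) * n]


def stream_transformed(
    original: Iterable[Entry], transforms: Sequence
) -> Iterator[Entry]:
    """Feed original state dict entries through all transforms, one at a time."""
    for name, value in original:
        for t in transforms:  # applied in the order the transforms were added
            name, value = t.transform_entry(name, value)
        yield name, value


if __name__ == "__main__":
    original = {"linear.weight": [0.2, 1.7, 2.9, 4.1], "linear.bias": [0.4, 0.6]}
    # Mirrors fsdp(quantize(jit(m))): quantize first, then shard.
    transforms = [RoundTransform(), ShardTransform(rank=1, world_size=2)]
    for name, value in stream_transformed(original.items(), transforms):
        print(name, value)
```

Because `stream_transformed` is a generator, the caller can materialize each transformed entry on the target device as it arrives and drop the original, which is what would allow everything else to stay on meta in the meantime. Saving would need the reverse mapping per transform, applied in the opposite order; that part is left out here.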
So the goal would be to be able to do