kshitij12345 opened this issue 1 month ago (status: Open).
Indeed, and this is tricky: `meta` is a great way to keep memory in check, but `meta` does not help us when we want to keep the weights' values. Maybe we should offer replacing the original weights with `meta` equivalents as an option?

Triage review:
- We should be careful that retracing has the information it needs (possibly by observing the original values) to work as expected.

@mruberry could you elaborate on what this means?
Two parts:
I would like to see the solution to #483 / #564 enable moving materialization out of the sharding step and doing it before we run the model, then propagating data through what we have for #483 (which needs to deal with "has been moved to meta", too).
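For reference, materializing a `meta`-initialized module ahead of time, before it is run or sharded, looks roughly like this in plain PyTorch. This is a sketch under assumptions: it uses only core PyTorch (2.0 or later, where `torch.device` works as a context manager), not thunder's own materialization path, and the toy `nn.Sequential` model with fresh initialization via `reset_parameters()` stands in for whatever real weights would be used.

```python
import torch
import torch.nn as nn

# Construct the module on the meta device: parameters get shapes and dtypes
# but no storage, so construction costs (almost) no memory.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

assert all(p.is_meta for p in model.parameters())

# Materialize up front, before the model is run (or sharded):
# allocate uninitialized storage on a real device...
model = model.to_empty(device="cpu")

# ...then fill it with values. Here we simply re-run each layer's default
# initialization; real weights would instead be loaded at this point.
for m in model.modules():
    if isinstance(m, nn.Linear):
        m.reset_parameters()

out = model(torch.randn(2, 16))
```

Because the values here come from `reset_parameters()`, this only covers the case where fresh initialization is acceptable; keeping pre-existing weight values means loading them after `to_empty`, which is where `meta` alone stops helping.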
As `fsdp(jit(...))` holds on to the original parameters as well as the sharded parameters, it can lead to higher memory usage. I think a workaround could be to initialize the original model on the `meta` device. But if using `meta` is the only correct way, then we should add a warning when the user does otherwise.

cc: @t-vi
cc @carmocca @awaelchli @crcrpar
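The two options raised in this thread, warning when the model is not on `meta` and offering to replace the original weights with `meta` equivalents, could be sketched in plain PyTorch as follows. Both `warn_if_not_meta` and `release_original_weights` are hypothetical helper names, not existing thunder APIs, and the sketch does not touch `fsdp(jit(...))` itself; it assumes the sharded copies already exist, so the original values are no longer needed.

```python
import warnings

import torch
import torch.nn as nn


def warn_if_not_meta(module: nn.Module) -> bool:
    """Hypothetical check: warn if any parameter holds real storage.

    Returns True if a warning was emitted, False otherwise.
    """
    materialized = [name for name, p in module.named_parameters() if not p.is_meta]
    if materialized:
        warnings.warn(
            f"{len(materialized)} parameter(s) are not on the meta device; "
            "keeping the original parameters alongside the sharded ones may "
            "increase memory usage.",
            stacklevel=2,
        )
        return True
    return False


def release_original_weights(module: nn.Module) -> None:
    """Hypothetical opt-in: replace every parameter with a meta tensor of the
    same shape and dtype so its storage can be freed. The values are lost.

    Note: tied (shared) parameters become separate meta parameters, and the
    old storage is only freed once nothing else references it.
    """
    for submodule in module.modules():
        # Copy to a list because we re-assign parameters while iterating.
        for name, param in list(submodule.named_parameters(recurse=False)):
            meta_param = nn.Parameter(
                torch.empty_like(param, device="meta"),
                requires_grad=param.requires_grad,
            )
            setattr(submodule, name, meta_param)


if __name__ == "__main__":
    model = nn.Linear(8, 8)
    print(warn_if_not_meta(model))  # True (and a UserWarning is emitted)
    release_original_weights(model)
    print(warn_if_not_meta(model))  # False
```

Once the weights are swapped for `meta` tensors, nothing downstream can observe the original values any more, which is the retracing concern from the triage review; that is one reason to keep this as an explicit opt-in rather than the default.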