pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[DCP] DCP does not support objects which are lazy initialized. #126881

Open LucasLLC opened 1 month ago

LucasLLC commented 1 month ago

🚀 The feature, motivation and pitch

DCP lacks support for objects that are lazily initialized, e.g.:

```python
import copy
import torch.distributed.checkpoint as dcp

dl_state = {"dl_state": {"data": {"a": "b"}}}

checkpoint_id = "tmp_ckpt"
dcp.save(state_dict=dl_state, checkpoint_id=checkpoint_id)

# simulates an object which is lazily initialized (in this case 'data' does not exist yet)
dl_state_loaded = copy.deepcopy(dl_state)
dl_state_loaded["dl_state"].pop("data")
dcp.load(state_dict=dl_state_loaded, checkpoint_id=checkpoint_id)

print(dl_state_loaded)
# 'data' is nested under the top-level 'dl_state' key
assert "data" in dl_state_loaded["dl_state"] and "a" in dl_state_loaded["dl_state"]["data"]  # fails
```

Since lazy initialization is a fairly common practice (e.g. optimizers, dataloaders), this is a real problem. Additionally, although DCP saves the entire state, it loads without any warning when keys exist in the serialized checkpoint but not in the local state_dict. This is effectively a 'silent failure to load', and it is confusing to users.

Proposal:

The main reason DCP does not support lazily initialized objects is that it loads 'in place'. If an object does not exist in the local state_dict, there is nothing for DCP to copy into, so that object is ignored.
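To make the in-place rule concrete, here is a toy model in plain Python. It is not DCP's implementation: `flatten` and `toy_load_in_place` are hypothetical helpers, and the sketch assumes checkpoint entries are addressed by dotted keys such as `dl_state.data.a`. It mirrors the behavior from the example above: a leaf that was never allocated locally is skipped without any error, while a pre-allocated slot is filled.

```python
from typing import Any, Dict


def flatten(state: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into dotted keys: {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in state.items():
        fqn = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, fqn + "."))
        else:
            flat[fqn] = value
    return flat


def toy_load_in_place(local: Dict[str, Any], checkpoint: Dict[str, Any]) -> None:
    """Fill only leaves that already exist in `local`; everything else is skipped."""
    flat_ckpt = flatten(checkpoint)

    def visit(node: Dict[str, Any], prefix: str) -> None:
        for key, value in list(node.items()):
            fqn = f"{prefix}{key}"
            if isinstance(value, dict):
                visit(value, fqn + ".")
            elif fqn in flat_ckpt:
                node[key] = flat_ckpt[fqn]  # overwrite the pre-existing slot

    visit(local, "")


saved = {"dl_state": {"data": {"a": "b"}}}

lazy = {"dl_state": {}}  # 'data' has not been created yet
toy_load_in_place(lazy, saved)
print(lazy)  # {'dl_state': {}}: "dl_state.data.a" was silently skipped

eager = {"dl_state": {"data": {"a": None}}}  # slot allocated up front
toy_load_in_place(eager, saved)
print(eager)  # {'dl_state': {'data': {'a': 'b'}}}
```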

An alternative is to make a best effort at loading objects that do not exist in the local state_dict, by creating each such object and inserting it into the local state_dict. Additionally, we could add a configuration option to raise warnings for 'extra' keys (keys present in the checkpoint but not locally).
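A rough sketch of what this proposal could look like, again on plain nested dicts rather than real DCP internals. `best_effort_load`, `create_missing`, and `warn_on_extra` are hypothetical names chosen for illustration, not existing DCP options. The sketch keeps today's in-place behavior for keys that already exist, then, depending on the flags, warns about and/or creates the keys that only exist in the checkpoint. It assumes no key contains a literal `.`.

```python
import warnings
from typing import Any, Dict, List


def flatten(state: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into dotted keys: {"a": {"b": 1}} -> {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in state.items():
        fqn = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, fqn + "."))
        else:
            flat[fqn] = value
    return flat


def set_nested(state: Dict[str, Any], fqn: str, value: Any) -> None:
    """Write `value` at dotted key `fqn`, creating intermediate dicts as needed."""
    *parents, leaf = fqn.split(".")
    node = state
    for part in parents:
        node = node.setdefault(part, {})
        if not isinstance(node, dict):
            raise ValueError(f"cannot create {fqn!r}: {part!r} is not a dict")
    node[leaf] = value


def best_effort_load(
    local: Dict[str, Any],
    checkpoint: Dict[str, Any],
    create_missing: bool = True,
    warn_on_extra: bool = False,
) -> List[str]:
    """Load in place, then optionally create keys that exist only in the checkpoint.

    Returns the sorted 'extra' keys (present in the checkpoint but not in `local`).
    """
    flat_ckpt = flatten(checkpoint)
    flat_local = flatten(local)
    extra = sorted(k for k in flat_ckpt if k not in flat_local)

    # Current behavior: overwrite only the keys the caller already allocated.
    for fqn in flat_local:
        if fqn in flat_ckpt:
            set_nested(local, fqn, flat_ckpt[fqn])

    if extra and warn_on_extra:
        warnings.warn(f"checkpoint keys missing from local state_dict: {extra}")

    # Proposed behavior: materialize the missing objects instead of dropping them.
    if create_missing:
        for fqn in extra:
            set_nested(local, fqn, flat_ckpt[fqn])
    return extra


saved = {"dl_state": {"data": {"a": "b"}}}
local = {"dl_state": {}}
extra = best_effort_load(local, saved, warn_on_extra=True)
print(extra)  # ['dl_state.data.a']
print(local)  # {'dl_state': {'data': {'a': 'b'}}}
```

Keeping both behaviors behind explicit flags would let callers doing intentional partial loads opt out of creation and warnings.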

Alternatives

The alternative to this proposal is to require all DCP users to avoid lazy initialization patterns. In the example above, we could recommend that the entire "data" object be serialized and exist as a bytes object at load time (I'll admit this is a little hand-wavy, but that is the general idea). The important distinction is that the object would always exist and would be properly replaced at load time.
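A toy illustration of this workaround, assuming the object can be pickled (pickle is an arbitrary choice here, and the two-level `toy_load_in_place` helper is hypothetical, not a DCP function). Because the local state_dict always holds a bytes placeholder under `data`, the in-place rule has a slot to write into, and the real object is rebuilt from the bytes after loading.

```python
import pickle
from typing import Any, Dict


def toy_load_in_place(local: Dict[str, Dict[str, Any]], saved: Dict[str, Dict[str, Any]]) -> None:
    """Two-level toy of the in-place rule: only keys already present in `local` are filled."""
    for outer, inner in local.items():
        for key in list(inner):
            if key in saved.get(outer, {}):
                inner[key] = saved[outer][key]


# Save side: the whole (possibly lazily built) object is pickled into one bytes value.
data = {"a": "b"}
saved = {"dl_state": {"data": pickle.dumps(data)}}

# Load side: the key exists up front as an empty bytes placeholder,
# even though the real 'data' object has not been constructed yet.
local = {"dl_state": {"data": b""}}
toy_load_in_place(local, saved)

restored = pickle.loads(local["dl_state"]["data"])
print(restored)  # {'a': 'b'}
```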

From a UX perspective, I don't believe this is a good idea because: a) users should expect a load API to handle this case, and b) lazy initialization is a fairly common practice, and this behavior is subtle and surprising to users (which makes it bad UX).

Additional context

I'm not suggesting this as a replacement for the current lazy initialization logic in optimizers. It might (or might not) be a step toward avoiding lazy initialization there, but either way optimizers present additional complexity, since they currently need to be sharded before calling load.

In practice, this sharding concern doesn't apply to most other objects, such as dataloaders, since the expectation is that an object that supports re-sharding either: a) is a DTensor that exists before calling dcp.load, or b) manages its own sharding logic in load_state_dict.

kirtiteja commented 1 month ago

This behavior is well documented here https://pytorch.org/docs/stable/distributed.checkpoint.html.

It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.

Is it really confusing? Can you share more context on the exact issue we saw? Do you think we can address this by improving documentation on the load API?

For a distributed checkpoint, raising warnings on missing keys or loading them implicitly is problematic as not all ranks load all data. The load API is also used to load only some params (partial checkpoint loads for post process or introspection) and it would be annoying to see warnings for all keys in these cases.

LucasLLC commented 1 month ago

> This behavior is well documented here https://pytorch.org/docs/stable/distributed.checkpoint.html.
>
> It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
>
> Is it really confusing? Can you share more context on the exact issue we saw? Do you think we can address this by improving documentation on the load API?

Granted, some additional documentation would be helpful for this issue, but I still don't believe it's the expected behavior from a load API. I think in general, users see a load as 'grab everything which I previously saved'. Recently this has come up for certain cases where the dataloader is not fully initialized and we are attempting to restore state. DCP silently fails to load the dataloader state, which causes errors in the ensuing dataloader.load_state_dict call.

I would argue that DCP should skip objects present in the checkpoint only when this is explicitly requested, as is the case for partial loading.

> For a distributed checkpoint, raising warnings on missing keys or loading them implicitly is problematic as not all ranks load all data.

I don't really believe we support loading entire objects on some ranks but not others. The closest thing we support in DCP is through DTensor, where we only load the local tensor shards into the DTensor.

> The load API is also used to load only some params (partial checkpoint loads for post process or introspection) and it would be annoying to see warnings for all keys in these cases.

I agree, but we could always account for this in our partial load APIs and avoid warnings when we know skipping keys is the intended behavior. Furthermore, even for partial loads, say we only want to load certain modules from a checkpoint, shouldn't we at least warn if parts of those modules are missing locally?

E.g.

Locally we have something like:

```
{
    module.net1,
    module.net2
}
```

and in the checkpoint we have:

```
{
    module.net1,
    module.net2,
    module.net3,
    optimizer.state.0,
    optimizer.state.1
}
```

If we call `dcp.state_dict_loader._load_from_state_dict_keys("module")`, I think users would expect us to raise, since net3 exists in the checkpoint but not locally.
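A sketch of the check being proposed, using the key sets from the example. `extra_keys` and `check_partial_load` are hypothetical helpers and do not reproduce the behavior of any existing DCP loader. Scoping the comparison to the requested prefix is what keeps a partial load quiet about unrelated keys (such as the optimizer state) while still flagging `module.net3`.

```python
from typing import Iterable, List


def extra_keys(local_keys: Iterable[str], ckpt_keys: Iterable[str], prefix: str = "") -> List[str]:
    """Checkpoint keys (under `prefix`, if given) that have no local counterpart."""
    local = set(local_keys)

    def in_scope(key: str) -> bool:
        # An empty prefix means "the whole checkpoint".
        return not prefix or key == prefix or key.startswith(prefix + ".")

    return sorted(k for k in ckpt_keys if in_scope(k) and k not in local)


def check_partial_load(local_keys: Iterable[str], ckpt_keys: Iterable[str], prefix: str) -> None:
    """Raise if the requested subtree is only partially present locally."""
    missing = extra_keys(local_keys, ckpt_keys, prefix)
    if missing:
        raise KeyError(f"checkpoint keys under {prefix!r} not found locally: {missing}")


local_keys = ["module.net1", "module.net2"]
ckpt_keys = [
    "module.net1", "module.net2", "module.net3",
    "optimizer.state.0", "optimizer.state.1",
]

# Unscoped: every optimizer key is reported too, which is the noise a partial load wants to avoid.
print(extra_keys(local_keys, ckpt_keys))  # ['module.net3', 'optimizer.state.0', 'optimizer.state.1']
# Scoped to "module": only the genuinely missing module.net3 is reported.
print(extra_keys(local_keys, ckpt_keys, prefix="module"))  # ['module.net3']

try:
    check_partial_load(local_keys, ckpt_keys, prefix="module")
except KeyError as err:
    print(err)
```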

Overall, I think the current design makes it very easy for users to fail in subtle ways. Also, if all of this is guarded behind explicit configuration options, I don't see it as a large risk.

cc @kirtiteja

andrewkho commented 1 month ago

+1 to this issue. For dataloading, this currently requires some fairly manual workarounds, and since the dataloader state relies on an API that users are responsible for defining, this requirement introduces footguns that IMO are not obvious. It also forces eager initialization of some things (such as the multiprocessing iterator) that may be in the critical path for fail-fast checks outside of the dataloader.
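One shape such a manual workaround might take, sketched in plain Python: `LazyLoader` is a hypothetical class, not a PyTorch dataloader. It uses a `state_dict()`/`load_state_dict()` method pair (the same method names as DCP's `Stateful` protocol), reports the same keys whether or not its iterator exists, and defers applying restored state until the iterator is lazily created, so restoring state does not force eager initialization. It assumes `load_state_dict` is called before iteration starts.

```python
from typing import Any, Dict, Iterator, List, Optional


class LazyLoader:
    """Hypothetical loader whose iterator is built lazily on the first next() call.

    Assumes load_state_dict() is called before iteration begins.
    """

    def __init__(self, data: List[int]) -> None:
        self._data = data
        self._it: Optional[Iterator[int]] = None
        self._consumed = 0
        self._pending: Optional[Dict[str, Any]] = None

    def state_dict(self) -> Dict[str, Any]:
        # Same keys whether or not the iterator exists, so an in-place
        # loader always has a slot to write into.
        if self._pending is not None:
            return dict(self._pending)
        return {"consumed": self._consumed}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Defer: remember the state and apply it when the iterator is created.
        self._pending = dict(state)

    def __iter__(self) -> "LazyLoader":
        return self

    def __next__(self) -> int:
        if self._it is None:  # lazy initialization happens here
            if self._pending is not None:
                self._consumed = self._pending["consumed"]
                self._pending = None
            self._it = iter(self._data[self._consumed:])
        value = next(self._it)
        self._consumed += 1
        return value


loader = LazyLoader([10, 20, 30, 40])
next(loader)
next(loader)
saved = loader.state_dict()  # {'consumed': 2}

fresh = LazyLoader([10, 20, 30, 40])
fresh.load_state_dict(saved)  # no iterator is created here
print(list(fresh))  # [30, 40]
```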