mrshenli opened this issue 4 years ago
Adding two cents on ShardedOptimizer, since I've been working on something similar in fairscale: not all concepts from the upstream pytorch optimizer translate 1:1, since there are N+1 states basically, versus 1 state. Each rank has its own state, and there's a virtual, distributed state.

Some of the pytorch optimizer interfaces implicitly assume that there's a single state, for instance the .param_groups attribute, or state_dict().
It's not just theoretical, because some frameworks (ClassyVision for instance, but probably others) have taken advantage of these interfaces and are actively using them. Checkpointing becomes an issue if you pull state_dict() from a given replica and assume that you have the full state, and changing LRs per layer can also be an issue if you're manipulating the param_groups from a rank that does not have your layer.
I don't think that all of this is blocking per se; there are more or less elegant ways to make that work, but I'm just adding a 1.1 phase basically: identify the new APIs really needed and agree on something which works well for everyone. Curious to hear your thoughts!
Regarding ShardedOptimizer and the OSS implementation in fairscale: another big change in the API is that users are expected to call consolidate_state_dict on the optimizer at the end of every epoch. We are able to hide this in Classy Vision because our optimizer API is already aware of epochs (for param scheduling), but for general PyTorch users this is another extra step.
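For concreteness, a rough sketch of where that extra call sits in an otherwise plain training loop (exact argument names may differ between fairscale versions):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def train_epoch_and_checkpoint(model, optimizer, dataloader, path):
    # Plain training loop; nothing sharding-specific here.
    for inputs, targets in dataloader:
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # The extra step for fairscale's OSS: pull every rank's optimizer shard
    # onto one rank so that state_dict() returns the full state there.
    optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        torch.save(optimizer.state_dict(), path)
```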
> Regarding ShardedOptimizer and the OSS implementation in fairscale: another big change in the API is that users are expected to call consolidate_state_dict on the optimizer at the end of every epoch. We are able to hide this in Classy Vision because our optimizer API is already aware of epochs (for param scheduling), but for general PyTorch users this is another extra step.
Looping back to other discussions here: one solution for this particular .state_dict() conformance issue could be to use RRef from torch RPC, so that the full state could materialize on whatever rank is pulling it. A short-term blocker for this is that RPC and DDP need two separate inits for now (see #41614), which is cumbersome, but that may not be a fundamental blocker.
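Something along these lines, perhaps (rough sketch; `_LOCAL_OPTIM` and the helper are made up, and each worker is assumed to have initialized RPC under a known name):

```python
import torch.distributed.rpc as rpc

_LOCAL_OPTIM = None  # each rank sets this to its local optimizer shard at startup

def _local_optim_state():
    return _LOCAL_OPTIM.state_dict()

def full_optimizer_state(worker_names):
    # Materialize the complete state on whichever rank calls this: hold an
    # RRef to each worker's shard, then pull them over with to_here().
    shard_rrefs = [rpc.remote(name, _local_optim_state) for name in worker_names]
    return [rref.to_here() for rref in shard_rrefs]
```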
> Some of the pytorch optimizer interfaces implicitly assume that there's a single state, for instance the .param_groups attribute, or state_dict().
For .param_groups, do we still need the broadcast op in the optimizer if we have ShardedDataParallel? Looks like ShardedDataParallel can work with local optimizers, and the local optimizer does not need to do broadcast anymore as the next forward pass can pull updated parameters?
For state_dict(), yep, RRef should be able to support this use case if the requirement is to collect all model states on one process. I wonder if ProcessGroup::send/recv would be sufficient as well? We could, e.g., specify a root rank and have all processes send their model states to that root. Compared to the RRef solution, the difference would be that all processes need to call state_dict() and only the root passes it to torch.save, whereas the RRef solution only requires the root calling state_dict().
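A sketch of the send-to-root flavor (shown with the object-collective wrapper for brevity; raw send/recv of serialized byte tensors would follow the same pattern, and backend support varies):

```python
import torch
import torch.distributed as dist

def save_full_state(optimizer, path, root=0):
    local_shard = optimizer.state_dict()          # every rank calls state_dict()
    shards = [None] * dist.get_world_size() if dist.get_rank() == root else None
    dist.gather_object(local_shard, shards, dst=root)
    if dist.get_rank() == root:                   # only the root writes the file
        torch.save(shards, path)
```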
> Some of the pytorch optimizer interfaces implicitly assume that there's a single state, for instance the .param_groups attribute, or state_dict().
>
> For .param_groups, do we still need the broadcast op in the optimizer if we have ShardedDataParallel? Looks like ShardedDataParallel can work with local optimizers, and the local optimizer does not need to do broadcast anymore as the next forward pass can pull updated parameters?
No, you're right, this is actually being worked on; I think fairscale's implementation can indeed be improved.
> For state_dict(), yep, RRef should be able to support this use case if the requirement is to collect all model states on one process. I wonder if ProcessGroup::send/recv would be sufficient as well?
Right now NCCL does not support that unfortunately, I think that this is being changed though
> Right now NCCL does not support that unfortunately, I think that this is being changed though
Yep, code is available in ProcessGroupNCCL but disabled for OSS as we are still trying to figure out a weird error in CI tests (#39984, #42514).
Regarding the API, if this has to be built on top of RPC and RRef, instead of passing in a cpu_model_shard, we can use RemoteModule, which is basically an RRef to a module on a different process.

Besides the concern of having both init_process_group and init_rpc (I am confident we can fix that), is there any other downside of letting Sharded* sit on top of both RPC and c10d? Or should we just build it on top of RPC, as we can implement gather/broadcast on RPC as well?
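For reference, a minimal sketch of what handing out a RemoteModule instead of a cpu_model_shard could look like (worker name and device are made up; RPC is assumed to be initialized):

```python
import torch
import torch.nn as nn
from torch.distributed.nn import RemoteModule

# The nn.Linear is instantiated on "worker1"; locally we only hold an RRef to it.
remote_shard = RemoteModule("worker1/cpu", nn.Linear, args=(1024, 1024))

# forward() runs on worker1 and returns the output here.
out = remote_shard.forward(torch.randn(8, 1024))
```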
> Right now NCCL does not support that unfortunately, I think that this is being changed though
>
> Yep, code is available in ProcessGroupNCCL but disabled for OSS as we are still trying to figure out a weird error in CI tests (#39984, #42514).
>
> Regarding the API, if this has to be built on top of RPC and RRef, instead of passing in a cpu_model_shard, we can use RemoteModule, which is basically an RRef to a module on a different process.
>
> Besides the concern of having both init_process_group and init_rpc (I am confident we can fix that), is there any other downside of letting Sharded* sit on top of both RPC and c10d? Or should we just build it on top of RPC, as we can implement gather/broadcast on RPC as well?
I think it makes sense, it's a significant departure from the existing fairscale implementation (I'm not sure about deepspeed), but longer term it's the nicer solution IMO. I don't know of any significant drawback and it makes it possible to target a drop-in solution (same interface as non-sharded) I think
> `model = torch.distributed.optim.ShardedDataParallel(model)`
Is the model passed into ShardedDataParallel already a sharded version of the model, where it is the user's responsibility to do this sharding? Looking at deepspeed.initialize, this does seem to be their intention.
```python
class ShardedOptimizer(torch.optim.Optimizer):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = self.optim.step(closure=closure)
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    dist.broadcast(param, rank, group=self.group)
        return loss
```
I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?
```python
class ShardedDataParallel(nn.Module):
    def __init__(self, cpu_model: nn.Sequential, pg: ProcessGroup):
        # split the input model into shards based on config/profiling/etc.
        self.model_shards = split(cpu_model, pg)
```
The nn.Sequential here might not fit in GPU/CPU memory of a single host for very large models. Would we be using RemoteModule here? Also, this API is slightly different from deepspeed.initialize, since I believe they require the user to pass in a sharded model.
> Existing torch.distributed.DistributedDataParallel provides gradient bucketing and comm/comp overlapping features through the Reducer, which would be useful for DeepSpeed too, especially as DeepSpeed targets large models. Instead of implementing everything from scratch for DeepSpeed, we can try to integrate DeepSpeed into DDP.
I'm wondering if we should pull out the common building blocks here (bucketing and comm/comp overlap) and have two separate APIs instead. I feel that the DDP Reducer code is already pretty complex and augmenting it to support DeepSpeed might add a lot of different options that we need to think about while working on DDP. If we have two separate frameworks DeepSpeed and DDP that share some common building blocks, it might be easier to iterate on each one separately.
> I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?
@msbaines is the original author so he could probably have a say, but I don't get your point, could you elaborate? In that case all the ranks have the full model, so everyone needs to get the updated version of the params. Each optimizer shard does its step and ends up with a new version of its chunk of the params, which then has to reach all the other ranks; broadcast (with varying sources) does that. It does take a lot of RAM and I may be missing something, curious to hear your thoughts.
As discussed with @mrshenli above, this is only true of a sharded optimizer in a vacuum; if the rest of the ZeRO feature suite is there, then it would need to be significantly rewritten.
> I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?
`dist.broadcast(param, rank, group=self.group)`

broadcast is one-to-many. The rank with the shard is broadcasting the param to the other ranks.
One simple optimization that we planned to do was use per-shard batch buffers to reduce broadcast calls. The current implementation is functional but not performant.
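Roughly what the batched-buffer version could look like (sketch only; assumes all ranks hold the same parameters in the same order):

```python
import torch
import torch.distributed as dist

def broadcast_shard(params, src_rank, group=None):
    # Pack the whole shard into one flat buffer so it goes out in a single
    # broadcast call instead of one call per parameter tensor.
    flat = torch.cat([p.data.reshape(-1) for p in params])
    dist.broadcast(flat, src=src_rank, group=group)
    # Unpack the received buffer back into the individual parameters.
    offset = 0
    for p in params:
        n = p.numel()
        p.data.copy_(flat[offset:offset + n].view_as(p))
        offset += n
```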
Ah, I misunderstood. I think you're asking about Pos+g+p. We've only implemented Pos (optimizer state sharding). With Pos+g+p the parameters are also sharded.

- os - optimizer state
- g - gradient
- p - parameters

g requires a reduce of gradients during backward; p requires a broadcast during forward.
> Is the model passed into ShardedDataParallel already a sharded version of the model, where it is the user's responsibility to do this sharding? Looking at deepspeed.initialize, this does seem to be their intention.
This is open for discussion. Initially I wasn't sure whether ShardedDataParallel should sit on top of RPC and RemoteModule. If there are no better options, then this can be the way to go.
> I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?
This is to allow DeepSpeed's optimizer states sharding algorithm to work with other non-DeepSpeed training techniques. Given DeepSpeed's blog post, the optimizer states can be a few times larger than the model state. So I guess there can be use cases where the model can fit in one GPU but the optimizer states cannot. In such cases, applications can use fully-replicated models + a sharded optimizer. As @blefaudeux mentioned above, when all DeepSpeed's algorithms are in, we might need to rewrite the sharded optimizer. Actually, thinking about this again, with ShardedDataParallel, regular local optimizers might be sufficient, as the forward pass will pull the updated parameters anyway?
> The nn.Sequential here might not fit in GPU/CPU memory of a single host for very large models. Would we be using RemoteModule here? Also, this API is slightly different from deepspeed.initialize, since I believe they require the user to pass in a sharded model.
Agree, but would also like to see if there are any other suggestions. The downside of using the current RemoteModule is that it requires applications to do manual model decomposition. I wonder if there is a better way that we can automate this process. E.g., I wonder if we can have things like un-materialized/lazy modules, whose __init__ function just records some configurations instead of creating real tensors.
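For illustration, a hypothetical un-materialized module could be as simple as the following (made-up helper, not an existing API):

```python
import torch.nn as nn

class LazyModule:
    # __init__ only records the constructor and its arguments; no tensors
    # are allocated until materialize() is called (e.g. on the owning rank).
    def __init__(self, module_cls, *args, **kwargs):
        self.module_cls = module_cls
        self.args = args
        self.kwargs = kwargs

    def materialize(self, device="cpu"):
        return self.module_cls(*self.args, **self.kwargs).to(device)

# No parameters exist yet; they are only created where materialize() runs.
lazy = LazyModule(nn.Linear, 1024, 1024)
layer = lazy.materialize("cpu")
```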
> I'm wondering if we should pull out the common building blocks here (bucketing and comm/comp overlap) and have two separate APIs instead. I feel that the DDP Reducer code is already pretty complex and augmenting it to support DeepSpeed might add a lot of different options that we need to think about while working on DDP. If we have two separate frameworks DeepSpeed and DDP that share some common building blocks, it might be easier to iterate on each one separately.
Agree that the current Reducer is already quite complex. I am thinking about creating a ReducerBase where all four components are pluggable. Will finalize this after we know concretely what's needed to support DeepSpeed.
@mrshenli @msbaines @blefaudeux Thanks for your comments regarding the optimizer sharding. I understand the reason for the broadcast now since I originally had Pos + g + p in mind, but as Mandeep said we've only implemented Pos so far.
> Actually, thinking about this again, with ShardedDataParallel, regular local optimizers might be sufficient, as the forward pass will pull the updated parameters anyway?
Yes, I do feel it would be cleaner for the forward pass to pull the updated parameters instead of the optimizer broadcasting it upfront.
> Agree, but would also like to see if there are any other suggestions. The downside of using the current RemoteModule is that it requires applications to do manual model decomposition. I wonder if there is a better way that we can automate this process. E.g., I wonder if we can have things like un-materialized/lazy modules, whose __init__ function just records some configurations instead of creating real tensors.
I'm wondering if DeepSpeed currently supports some sort of automation here? My cursory look seems to suggest they expect the user to partition the model themselves. Given the scope of this project, I feel a good starting point would be to expect the user to manually partition the model.
ShardedOptimizer works best with a DDP that does a reduce instead of an all_reduce. So it might be simpler to make ShardedOptimizer private and make ShardedDataParallel take a Type[Optimizer] and optimizer args, and make it responsible for sharding the optimizer. That way the public API is a single class, ShardedDataParallel, that works like DistributedDataParallel except that you tell it how to construct the Optimizer, and it requires an nn.Sequential.
```python
class ShardedDataParallel(nn.Module):
    def __init__(self, cpu_model: nn.Sequential, pg: ProcessGroup, optim: Type[Optimizer], **optim_args):
        ...
```
A trivial first implementation could be a simple sub-class of DDP that still does all_reduce and pulls in some of the logic from oss.py. Then add the other enhancements (i.e. Pg + Pp).
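A rough sketch of what that trivial first version could look like (the partitioning helper is a toy stand-in for what oss.py does; not a drop-in implementation):

```python
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from typing import List, Type

def _partition_parameters(params: List[nn.Parameter], world_size: int) -> List[List[nn.Parameter]]:
    # Toy greedy partition by element count; oss.py balances per param group.
    shards: List[List[nn.Parameter]] = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for p in sorted(params, key=lambda p: p.numel(), reverse=True):
        owner = sizes.index(min(sizes))
        shards[owner].append(p)
        sizes[owner] += p.numel()
    return shards

class ShardedDataParallel(DDP):
    def __init__(self, module: nn.Module, optim: Type[Optimizer], **optim_args):
        super().__init__(module)  # gradients still go through DDP's all_reduce
        self._shards = _partition_parameters(list(module.parameters()), dist.get_world_size())
        # Each rank only builds optimizer state for the shard it owns.
        self.optim = optim(self._shards[dist.get_rank()], **optim_args)

    def optim_step(self) -> None:
        self.optim.step()
        # Owners broadcast their freshly updated parameters to everyone else.
        for owner, shard in enumerate(self._shards):
            for p in shard:
                dist.broadcast(p.data, src=owner)
```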
Hey @msbaines, after adding #39272, looks like we might be able to use reduce in the DDP comm hook, and then combine that with the ShardedOptimizer?
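Very roughly, something like the following, assuming the hook can tell which rank owns the parameters in a given bucket (the `bucket_owner` mapping is made up, and the comm-hook/bucket API has changed across versions, so take the exact calls as illustrative):

```python
import torch.distributed as dist

def make_reduce_hook(bucket_owner):
    def reduce_hook(process_group, bucket):
        group = process_group if process_group is not None else dist.group.WORLD
        tensor = bucket.buffer()            # flattened grads for this bucket
        tensor.div_(group.size())
        dst = bucket_owner(bucket)          # rank whose optimizer shard covers these params
        work = dist.reduce(tensor, dst=dst, group=group, async_op=True)
        return work.get_future().then(lambda fut: fut.value()[0])
    return reduce_hook

# ddp_model.register_comm_hook(state=None, hook=make_reduce_hook(owner_fn))
```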
> Hey @msbaines, after adding #39272, looks like we might be able to use reduce in the DDP comm hook, and then combine that with the ShardedOptimizer?
Cool. That's a nice way to do reduce. But how would you combine ShardedOptimizer? The buckets would have to be shard-aligned.
> The buckets would have to be shard-aligned.
Good point! we at least need to allow custom bucketing to make it work.
> The buckets would have to be shard-aligned.
>
> Good point! we at least need to allow custom bucketing to make it work.
Tipped by @pritamdamania, looks like this could be done via the state parameter, no? See

edit: I missed the static grad buckets not being properly aligned, my bad
> The buckets would have to be shard-aligned.
>
> Good point! we at least need to allow custom bucketing to make it work.
@mrshenli wrt the custom bucketing, is there some public plan for that? It would be super interesting for the Lightning integration; we would love to get to a proper reduce without having to fork the DDP engine. cc @ananthsub

edit: goes with https://github.com/pytorch/pytorch/issues/37002 I suppose
> wrt the custom bucketing, is there some public plan for that? It would be super interesting for the Lightning integration; we would love to get to a proper reduce without having to fork the DDP engine.
Yep, we are reviewing a detailed implementation plan for #37002 internally, and should be able to share it on GitHub in a few weeks.
Zero-3 has been released! https://news.ycombinator.com/item?id=26447018
Background
DeepSpeed reduces distributed data parallel training memory footprint by partitioning parameters, gradients, optimizer states, and recomputing activations. See this link for more details.
It takes care of the entire training iteration instead of exposing an nn.Module-like API. DeepSpeed1 provides optimizer states partitioning. DeepSpeed2 provides the full feature set.
API
To integrate DeepSpeed into PyTorch, it will be great if we can have a consistent API, i.e., model(inputs), loss.backward(), optimizer.step(). In this way, we can minimize the surprise to users and minimize the code changes when existing applications would like to adopt this technique. We can try to decompose DeepSpeed into the following two concepts, open to better names :)

- torch.distributed.nn.ShardedDataParallel: complies with the torch.nn.Module API.
- torch.distributed.optim.ShardedOptimizer: complies with the torch.optim.Optimizer API.

And training loops can look like the following, which is almost the same as local training.
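A rough reconstruction of that loop with the proposed names (these wrappers do not exist yet; the ShardedOptimizer constructor is assumed to mirror fairscale's OSS):

```python
import torch
import torch.nn.functional as F

model = torch.distributed.nn.ShardedDataParallel(my_model)        # proposed API
optimizer = torch.distributed.optim.ShardedOptimizer(             # proposed API
    model.parameters(), torch.optim.SGD, lr=0.01)

for inputs, labels in dataloader:
    loss = F.cross_entropy(model(inputs), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```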
Upstream DeepSpeed to PyTorch
ShardedOptimizer
The optimizer part is relatively independent, and fairscale has already implemented one using the PyTorch optimizer API. We can start the upstream effort from there. See oss.py.
ShardedDataParallel
I haven't got into all the details, but we might be able to wrap the DeepSpeed algorithm as an nn.Module using custom autograd functions. The high-level idea is that we can split the input model into shards and then insert phony layers between shards to load and drop model parameters/grads accordingly.
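A rough sketch of what those two building blocks could look like (illustrative only; a real version would also need to handle gradient offload, multiple inputs, and so on):

```python
import torch
import torch.nn as nn

class ModelShard(nn.Module):
    # One slice of the model; parameters stay off the compute device until
    # load() is called, and are released again with drop().
    def __init__(self, submodule: nn.Module, device: str):
        super().__init__()
        self.submodule = submodule
        self.device = device

    def load(self):
        self.submodule.to(self.device)

    def drop(self):
        self.submodule.to("cpu")

    def forward(self, x):
        return self.submodule(x)

class PhonyLayer(torch.autograd.Function):
    # Identity op inserted between shards so parameter load/drop happens at
    # the right points of the forward and backward passes.
    @staticmethod
    def forward(ctx, activations, prev_shard, next_shard):
        ctx.prev_shard, ctx.next_shard = prev_shard, next_shard
        prev_shard.drop()      # previous shard's params are no longer needed
        next_shard.load()      # bring in the params the next shard needs
        return activations.view_as(activations)

    @staticmethod
    def backward(ctx, grad_output):
        ctx.next_shard.drop()  # done backpropagating through the next shard
        ctx.prev_shard.load()  # previous shard's params are needed again
        return grad_output, None, None
```

ShardedDataParallel's forward would then run each shard in turn and call `PhonyLayer.apply(output, current_shard, next_shard)` before handing the activations to the next shard.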
With the above building blocks, ShardedDataParallel can assemble the PhonyLayer and ModelShard objects in its forward pass.

Integrate DeepSpeed into DDP
Existing torch.distributed.DistributedDataParallel provides gradient bucketing and comm/comp overlapping features through the Reducer, which would be useful for DeepSpeed too, especially as DeepSpeed targets large models. Instead of implementing everything from scratch for DeepSpeed, we can try to integrate DeepSpeed into DDP. The DeepSpeed optimizer can stay independent, and the rest of the DeepSpeed algorithms can be decomposed to fit into the Reducer API. However, today's Reducer API is not sufficient, as it materializes all buckets at construction time, installs all autograd hooks at construction time, and always uses AllReduce to communicate gradients. We will need the modularized DDP API. The current DDP algorithm can be decomposed and briefly represented using the following Python pseudo-code (although the current implementation is in C++). The four concepts (Grad Reader, Grad Bucketer, Comm Scheduler, and Grad Writer) in #37002 map to the four functions below.
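A rough Python reconstruction of that decomposition (the real Reducer is C++ and far more careful; this is only meant to show where the four pieces sit):

```python
import torch
import torch.distributed as dist

class Reducer:
    # Grad Bucketer: build_buckets; Grad Reader: register_grad_hooks;
    # Comm Scheduler: grad_ready; Grad Writer: bucket_ready.
    def __init__(self, params, bucket_cap_bytes=25 * 1024 * 1024):
        self.params = [p for p in params if p.requires_grad]
        self.buckets = self.build_buckets(self.params, bucket_cap_bytes)
        self.pending = [set(bucket) for bucket in self.buckets]
        self.grad_accs = []
        self.register_grad_hooks()

    def build_buckets(self, params, cap):
        # Greedily pack parameters into size-capped buckets.
        buckets, current, size = [], [], 0
        for p in params:
            current.append(p)
            size += p.numel() * p.element_size()
            if size >= cap:
                buckets.append(current)
                current, size = [], 0
        if current:
            buckets.append(current)
        return buckets

    def register_grad_hooks(self):
        # Hook each param's AccumulateGrad node so grad_ready fires once
        # the gradient has been accumulated into p.grad.
        for bucket_idx, bucket in enumerate(self.buckets):
            for p in bucket:
                acc = p.expand_as(p).grad_fn.next_functions[0][0]
                acc.register_hook(self._make_hook(bucket_idx, p))
                self.grad_accs.append(acc)  # keep the nodes alive

    def _make_hook(self, bucket_idx, p):
        def hook(*unused):
            self.grad_ready(bucket_idx, p)
        return hook

    def grad_ready(self, bucket_idx, p):
        # Comm scheduling: once every grad in a bucket is ready, ship it.
        self.pending[bucket_idx].discard(p)
        if not self.pending[bucket_idx]:
            self.bucket_ready(self.buckets[bucket_idx])
            self.pending[bucket_idx] = set(self.buckets[bucket_idx])  # reset for next iteration

    def bucket_ready(self, bucket):
        # Default DDP behavior: all_reduce and average the bucket's grads.
        for p in bucket:
            dist.all_reduce(p.grad)
            p.grad.div_(dist.get_world_size())
```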
The grad_ready function triggers the DDP backward logic.

To integrate DeepSpeed with DDP, we can override the build_buckets and bucket_ready functions accordingly to take care of the gradient communication in the backward pass. Besides, as not all parameters will be materialized at construction time, it should leave out the invocation of register_grad_hooks in the constructor as well.
The PhonyLayer would still be necessary to load and drop model shards. However, instead of gathering grads, it only needs to run grad_ready accordingly to trigger the Reducer code.
Then, the ShardedDataParallel nn.Module becomes a thin wrapper that assembles the ModelShard and PhonyLayer objects in its forward pass and relies on the overridden Reducer for gradient communication.

Implementation Plan
1. Start with ShardedOptimizer, as there are no road blockers as of today and fairscale already has an elegant implementation.
2. After the modularized Reducer API is ready, we can move on to add parameter + gradient sharding.

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski