
[RFC] DeepSpeed + PT Distributed Integration #42849

Open mrshenli opened 4 years ago

mrshenli commented 4 years ago

Background

DeepSpeed reduces the memory footprint of distributed data-parallel training by partitioning parameters, gradients, and optimizer states, and by recomputing activations. See this link for more details.

It takes over the entire training iteration rather than exposing an nn.Module-like API:

model_engine, optimizer, trainloader, _ = deepspeed.initialize(
    args=args,
    model=net,
    model_parameters=parameters,
    training_data=trainset
)

for inputs, labels in trainloader:
    outputs = model_engine(inputs)
    loss = criterion(outputs, labels)
    model_engine.backward(loss)
    model_engine.step()

DeepSpeed-1 provides optimizer state partitioning; DeepSpeed-2 provides the full feature set.

API

To integrate DeepSpeed into PyTorch, it would be great if we could keep a consistent API, i.e., model(inputs), loss.backward(), optimizer.step(). This minimizes surprise to users and the code changes needed when existing applications want to adopt this technique. We can try to decompose DeepSpeed into the following two concepts, ShardedDataParallel and ShardedOptimizer (open to better names :)).

And training loops can look like the following, which is almost the same as local training.

import torch.distributed as dist
import torch.optim as optim

dist.init_process_group(...)
model = torch.distributed.optim.ShardedDataParallel(model)
# model.parameters() only returns the parameters owned by the current rank
opt = torch.distributed.optim.ShardedOptimizer(optim.Adam, model.parameters())

loss_fn(model(inputs), labels).backward()
opt.step()

Upstream DeepSpeed to PyTorch

ShardedOptimizer

The optimizer part is relatively independent, and fairscale has already implemented one using the PyTorch optimizer API. We can start the upstream effort from there; see oss.py

class ShardedOptimizer(torch.optim.Optimizer):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = self.optim.step(closure=closure)
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    dist.broadcast(param, rank, group=self.group)
        return loss

ShardedDataParallel

I haven't gotten into all the details, but we might be able to wrap the DeepSpeed algorithm in an nn.Module using custom autograd functions. The high-level idea is to split the input model into shards and then insert phony layers between shards to load and drop model parameters/grads accordingly.

# Represents one shard of the model; exposes APIs to load/drop parameters
# and gather gradients.
class ModelShard(nn.Module):

    def __init__(self, cpu_model_shard, owner_rank, pg):
        super().__init__()
        self.module, self.owner_rank, self.pg = cpu_model_shard, owner_rank, pg
        ...

    def forward_load(self):
        # materialize local GPU parameters; can be enhanced with bucketing
        futures = [self.pg.broadcast(p, self.owner_rank, async_op=True)
                   for p in self.parameters()]
        # NB: this requires consolidating c10d Work with torch.futures.Future
        torch.futures.wait_all(futures)

    ...

    def reduce_grads(self):
        futures = [self.pg.reduce(p, self.owner_rank, async_op=True)
                   for p in self.parameters()]
        torch.futures.wait_all(futures)

# The phony layer is a synchronization point between model shards.
# In the forward pass, it drops parameters in the previous shard and
# loads parameters for the next shard. In the backward pass, it does
# the reverse and also gathers gradients to the owner.
# It does not change or create any outputs; it simply forwards the
# inputs as outputs.
class PhonyLayer(Function):

    @staticmethod
    def forward(ctx, prev_shard, next_shard, *inputs):
        # check for None accordingly
        prev_shard.forward_drop()
        next_shard.forward_load()
        ctx.prev_shard = prev_shard
        ctx.next_shard = next_shard
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        ctx.next_shard.reduce_grads()
        ctx.next_shard.backward_drop()
        ctx.prev_shard.backward_load()
        # no gradients for the two shard arguments
        return (None, None, *grad_outputs)

With the above building blocks ShardedDataParallel can assemble the PhonyLayer and ModelShard objects in its forward pass.

class ShardedDataParallel(nn.Module):
    def __init__(self, cpu_model: nn.Sequential, pg: ProcessGroup):
        super().__init__()
        # split the input model into shards based on config/profiling/etc.
        self.model_shards = split(cpu_model, pg)

    def forward(self, *inputs):
        shards = self.model_shards
        for prev, next in zip([None, *shards], [*shards, None]):
            # run the previous shard (if any), then let the phony layer swap shards
            inputs = prev(*inputs) if prev else inputs
            if not isinstance(inputs, tuple):
                inputs = (inputs,)
            inputs = PhonyLayer.apply(prev, next, *inputs)
        return inputs

Integrate DeepSpeed into DDP

Existing torch.distributed.DistributedDataParallel provides gradient bucketing and comm/comp overlapping through the Reducer, which would be useful for DeepSpeed too, especially as DeepSpeed targets large models. Instead of implementing everything from scratch for DeepSpeed, we can try to integrate DeepSpeed into DDP. The DeepSpeed optimizer can stay independent, and the rest of the DeepSpeed algorithms can be decomposed to fit into the Reducer API. However, today's Reducer API is not sufficient, as it materializes all buckets and installs all autograd hooks at construction time, and always uses AllReduce to communicate gradients. We will need the modularized DDP API.

The current DDP algorithm can be decomposed and briefly represented with the following Python pseudo-code (although the actual implementation is in C++). The four concepts (Grad Reader, Grad Bucketer, Comm Scheduler, and Grad Writer) in #37002 map to the four functions below. The grad_ready function triggers the DDP backward logic.

class ReducerBase:
    def __init__(self, model, pg, bucket_cap_mb=25):
        self.model, self.bucket_cap_mb, self.pg = model, bucket_cap_mb, pg
        self.grad2bucket = self.build_buckets()
        # let param.grad point to its bucket offset (a view)
        self.register_grad_hooks()

    def register_grad_hooks(self):
        for p in self.model.parameters():
            # register self.grad_ready to p's GradAccumulator as a callback
            ...

    # Grad Bucketer
    def build_buckets(self):
        # map self.model.parameters() to a list of GradBucket objects
        ...

    # Grad Reader
    def grad_ready(self, grad, idx):
        bucket = self.grad2bucket[idx]
        bucket.pending_grad -= 1
        if bucket.pending_grad == 0:
            # only launch a bucket when all its preceding ones have been
            # launched, so that allreduce comms won't be mismatched across ranks
            for i in range(self.min_pending_bucket, bucket.idx + 1):
                if self.buckets[i].pending_grad == 0:
                    self.bucket_ready(self.buckets[i])

    # Comm Scheduler
    def bucket_ready(self, bucket):
        self.futures.append(
            self.pg.allreduce(bucket, async_op=True).then(self.comm_ready)
        )
        self.pending_buckets -= 1
        if self.pending_buckets == 0:
            # the last bucket should block to make sure that when backward()
            # returns, all grads are indeed ready
            torch.futures.wait_all(self.futures)

    # Grad Writer
    def comm_ready(self, comm_future):
        pass

To integrate DeepSpeed with DDP, we can override the build_buckets and bucket_ready functions to take care of gradient communication in the backward pass. Besides, as not all parameters will be materialized at construction time, it should also skip the invocation of register_grad_hooks in the constructor.

class ShardedReducer(ReducerBase):
    ...
    def build_buckets(self):
        # map self.model.parameters() to a list of GradBucket objects,
        # only materializing the ones this rank owns
        ...

    # Comm Scheduler
    def bucket_ready(self, bucket):
        self.futures.append(
            self.pg.reduce(bucket, bucket_owner_map(bucket.idx), async_op=True)
        )
        self.model_shard.pending_buckets -= 1
        if self.model_shard.pending_buckets == 0:
            torch.futures.wait_all(self.futures)

The PhonyLayer would still be necessary to load and drop model shards. However, instead of gathering grads, it only needs to run grad_ready accordingly to trigger Reducer code.

class ModelShard(nn.Module):
    def __init__(self, cpu_model_shard, owner_rank, pg, reducer):
        ...
        self.reducer = reducer

    def register_grad_reader(self):
        # this shard must have been materialized by the time we reach here
        for p in self.parameters():
            # register self.reducer.grad_ready to p's GradAccumulator as a callback
            ...

# This is the same as the above PhonyLayer except that it no longer
# communicates grads on its own; instead it registers grad hooks
# on the dynamically loaded parameters in the backward pass.
class PhonyLayer(Function):

    @staticmethod
    def forward(ctx, prev_shard, next_shard, *inputs):
        # check for None accordingly
        prev_shard.forward_drop()
        next_shard.forward_load()
        ctx.prev_shard = prev_shard
        ctx.next_shard = next_shard
        return inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        # ctx.next_shard.reduce_grads() is no longer needed here
        # check for None accordingly
        ctx.next_shard.backward_drop()
        ctx.prev_shard.backward_load()
        ctx.prev_shard.register_grad_reader()
        # no gradients for the two shard arguments
        return (None, None, *grad_outputs)

Then, the ShardedDataParallel nn.Module becomes:

class ShardedDataParallel(nn.Module):

    def __init__(self, cpu_model: nn.Sequential, pg: ProcessGroup):
        super().__init__()
        # Split the input model into shards based on config/profiling/etc.
        # Let each model shard keep a reference to the reducer, so that it can
        # call reducer.grad_ready accordingly.
        self.reducer = ShardedReducer(cpu_model, pg)
        self.shards = split(cpu_model, pg, self.reducer)

    def forward(self, *inputs):
        # same as the above one
        ...

Implementation Plan

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @xush6528 @osalpekar @jiayisuse @agolynski

blefaudeux commented 4 years ago

Adding two cents on ShardedOptimizer, since I've been working on something similar in fairscale: not all concepts from the upstream pytorch optimizer translate 1:1, since there are N+1 states basically, versus 1 state. Each rank has its own state, and there's a virtual, distributed state. Some of the pytorch optimizer interfaces implicitly assume that there's a single state, for instance the .param_groups attribute, or state_dict().

It's not just theoretical, because some frameworks (ClassyVision for instance, but probably others) have taken advantage of these interfaces and are actively using them. Checkpointing becomes an issue if you pull state_dict() from a given replica and assume that you have the full state, and changing LRs per layer can also be an issue if you're manipulating the param_groups on a rank that does not own your layer.

I don't think any of this is blocking per se; there are more or less elegant ways to make it work. I'm basically just adding a phase 1.1: identify the new APIs that are really needed and agree on something that works well for everyone. Curious to hear your thoughts!

vreis commented 4 years ago

Regarding ShardedOptimizer and the OSS implementation in fairscale: another big change in the API is that users are expected to call consolidate_state_dict on the optimizer at the end of every epoch. We are able to hide this in Classy Vision cause our optimizer API is already aware of epochs (for param scheduling), but for general PyTorch users this is another extra step.
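
For reference, here is a minimal usage sketch of that extra consolidation step, assuming fairscale's OSS optimizer (fairscale.optim.OSS) and its consolidate_state_dict method as described above; exact signatures may differ across fairscale versions.

import torch
import torch.distributed as dist
from fairscale.optim import OSS  # assumes fairscale is installed

model = torch.nn.Linear(8, 8).cuda()
# shard the Adam state across the default process group
optimizer = OSS(model.parameters(), optim=torch.optim.Adam, lr=1e-3)

# ... training steps for one epoch ...

# the extra step being discussed: gather the sharded optimizer state
# (onto rank 0 by default) before checkpointing
optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "optim_checkpoint.pt")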

blefaudeux commented 4 years ago

Regarding ShardedOptimizer and the OSS implementation in fairscale: another big change in the API is that users are expected to call consolidate_state_dict on the optimizer at the end of every epoch. We are able to hide this in Classy Vision cause our optimizer API is already aware of epochs (for param scheduling), but for general PyTorch users this is another extra step.

Looping back to other discussions here: one solution for this particular .state_dict() conformance issue could be to use RRef from torch RPC, so that the full state could materialize on whatever rank is pulling it. A short-term blocker is that RPC and DDP need two separate inits for now (see #41614), which is cumbersome, but that may not be a fundamental blocker.
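
A minimal sketch of that RRef idea, assuming RPC has been initialized with workers named worker0..workerN-1 and a hypothetical per-rank helper local_optim_state_dict() that returns the locally-owned shard:

import torch.distributed.rpc as rpc

# hypothetical per-rank helper: returns the state_dict of the optimizer shard
# owned by the process it runs on
def local_optim_state_dict():
    return _local_sharded_optimizer.state_dict()  # assumed module-level sharded optimizer

# whichever rank wants the full (virtual) optimizer state pulls every shard on demand
def consolidated_optim_state(world_size):
    rrefs = [rpc.remote(f"worker{r}", local_optim_state_dict) for r in range(world_size)]
    return [rref.to_here() for rref in rrefs]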

mrshenli commented 4 years ago

Some of the pytorch optimizer interfaces implicitly assume that there's a single state, for instance the .param_groups attribute, or state_dict().

For .param_groups, do we still need the broadcast op in the optimizer if we have ShardedDataParallel? It looks like ShardedDataParallel can work with local optimizers, and the local optimizer no longer needs to broadcast, as the next forward pass can pull the updated parameters?

For state_dict(), yep, RRef should be able to support this use case if the requirement is to collect all model states to one process. I wonder if ProcessGroup::send/recv would be sufficient as well? We can, e.g., specify a root rank and have all processes send their model states to that root. Compared to the RRef solution, the difference would be that all processes need to call state_dict() and only the root passes it to torch.save, whereas the RRef solution only requires the root to call state_dict().
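
A hedged sketch of that ProcessGroup-based alternative, assuming an object collective such as dist.gather_object (available in recent PyTorch releases); every rank calls this, but only the root ends up with the full list of shards and writes the checkpoint:

import torch
import torch.distributed as dist

def save_consolidated(optimizer, path, root=0):
    local_shard = optimizer.state_dict()  # partial, rank-local state
    shards = [None] * dist.get_world_size() if dist.get_rank() == root else None
    dist.gather_object(local_shard, shards, dst=root)
    if dist.get_rank() == root:
        torch.save(shards, path)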

blefaudeux commented 4 years ago

Some of the pytorch optimizer interfaces implicitly assume that there's a single state, for instance the .param_groups attribute, or state_dict().

For .param_groups, do we still need the broadcast op in the optimizer if we have ShardedDataParallel? It looks like ShardedDataParallel can work with local optimizers, and the local optimizer no longer needs to broadcast, as the next forward pass can pull the updated parameters?

No you're right, this is being worked on actually, I think fairscale's implementation can be improved indeed

For state_dict(), yep, RRef should be able to support this use case if the requirement is to collect all model states to one process. I wonder if ProcessGroup::send/recv would be sufficient as well?

Right now NCCL does not support that unfortunately, I think that this is being changed though

mrshenli commented 4 years ago

Right now NCCL does not support that unfortunately, I think that this is being changed though

Yep, code is available in ProcessGroupNCCL but disabled for OSS as we are still trying to figure out a weird error in CI test (#39984 #42514).

Regarding the API, if this has to be built on top of RPC and RRef, instead of passing in a cpu_model_shard, we can use RemoteModule, which is basically an RRef to a module on a different process.

https://github.com/pytorch/pytorch/blob/9600ed9af3b84c000b7f54765495e96f29c4bf1d/torch/distributed/nn/api/remote_module.py#L33-L41
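
For illustration, a hedged sketch of holding a shard as a RemoteModule (run after rpc.init_rpc; the constructor shown follows the linked code and its signature may differ across versions):

import torch
import torch.nn as nn
from torch.distributed.nn import RemoteModule

# the shard is owned by another worker; locally we only hold a reference to it
remote_shard = RemoteModule("worker1/cpu", nn.Linear, args=(1024, 1024))

x = torch.randn(4, 1024)
out = remote_shard.forward(x)        # synchronous remote forward
fut = remote_shard.forward_async(x)  # or async, returning a torch.futures.Future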

Besides the concern of having both init_process_group and init_rpc (I am confident we can fix that), is there any other downside of letting Sharded* sit on top of both RPC and c10d? Or should we just build it on top of RPC, as we can implement gather/broadcast on RPC as well?

blefaudeux commented 4 years ago

Right now NCCL does not support that unfortunately, I think that this is being changed though

Yep, code is available in ProcessGroupNCCL but disabled for OSS as we are still trying to figure out a weird error in CI test (#39984 #42514).

Regarding the API, if this has to be built on top of RPC and RRef, instead of passing in a cpu_model_shard, we can use RemoteModule, which is basically an RRef to a module on a different process.

https://github.com/pytorch/pytorch/blob/9600ed9af3b84c000b7f54765495e96f29c4bf1d/torch/distributed/nn/api/remote_module.py#L33-L41

Besides the concern of having both init_process_group and init_rpc (I am confident we can fix that), is there any other downside of letting Sharded* sit on top of both RPC and c10d? Or should we just build it on top of RPC, as we can implement gather/broadcast on RPC as well?

I think it makes sense. It's a significant departure from the existing fairscale implementation (I'm not sure about DeepSpeed), but longer term it's the nicer solution IMO. I don't know of any significant drawback, and it makes it possible to target a drop-in solution (same interface as non-sharded), I think.

pritamdamania87 commented 4 years ago

model = torch.distributed.optim.ShardedDataParallel(model)

Is the model passed into ShardedDataParallel already a sharded version of the model where it is the user's responsibility to do this sharding? Looking at deepspeed.initialize, this does seem to be their intention.

class ShardedOptimizer(torch.optim.Optimizer):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = self.optim.step(closure=closure)
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    dist.broadcast(param, rank, group=self.group)
        return loss

I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?

class ShardedDataParallel(nn.Module):
    def __init__(self, cpu_model: nn.Sequential, pg: ProcessGroup):
        # split the input model into shards based on config/profiling/etc.
        self.model_shards = split(cpu_model, pg)

The nn.Sequential here might not fit in the GPU/CPU memory of a single host for very large models. Would we be using RemoteModule here? Also, this API is slightly different from deepspeed.initialize, since I believe they require the user to pass in a sharded model.

Existing torch.distributed.DistributedDataParallel provides gradient bucketing and comm/comp overlapping features through the Reducer, which would be useful for DeepSpeed too, especially as DeepSpeed targets large models. Instead of implementing everything from scratch for DeepSpeed, we can try to integrate DeepSpeed into DDP.

I'm wondering if we should pull out the common building blocks here (bucketing and comm/comp overlap) and have two separate APIs instead. I feel that the DDP Reducer code is already pretty complex and augmenting it to support DeepSpeed might add a lot of different options that we need to think about while working on DDP. If we have two separate frameworks DeepSpeed and DDP that share some common building blocks, it might be easier to iterate on each one separately.

blefaudeux commented 4 years ago

class ShardedOptimizer(torch.optim.Optimizer):
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = self.optim.step(closure=closure)
        for rank, param_groups in enumerate(self.partition_parameters()):
            for param_group in param_groups:
                for param in param_group["params"]:
                    dist.broadcast(param, rank, group=self.group)
        return loss

I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?

@msbaines is the original author so he could probably have a say, but I don't get your point, could you elaborate? In that case all the ranks have the full model, so every one of them needs to get the updated version of the params. The optimizer shard does its step, so it has a new version of its chunk of the params, but all the other ranks need to be updated too; broadcast (with varying sources) does that. It does take a lot of RAM and I may be missing something, curious to hear your thoughts.

As discussed with @mrshenli above, this is only true of a sharded optimizer in a vacuum; if the rest of the ZeRO feature suite is there then it would need to be significantly rewritten.

msbaines commented 4 years ago

I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?

dist.broadcast(param, rank, group=self.group)

broadcast is one-to-many. The rank with the shard is broadcasting the param to the other ranks.

One simple optimization that we planned to do was use per-shard batch buffers to reduce broadcast calls. The current implementation is functional but not performant.
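
As a rough illustration of that batching idea (a sketch only, not the fairscale implementation), the params owned by one rank could be packed into a single flat buffer and broadcast with one call:

import torch
import torch.distributed as dist

# broadcast all params owned by `src_rank` in one call by packing them into a
# single flat buffer, then scattering the result back into the original tensors
def broadcast_shard(params, src_rank, group=None):
    flat = torch.cat([p.data.reshape(-1) for p in params])
    dist.broadcast(flat, src=src_rank, group=group)
    offset = 0
    for p in params:
        n = p.numel()
        p.data.copy_(flat[offset:offset + n].view_as(p))
        offset += n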

Ah. I misunderstood. I think you're asking about Pos+g+p. We've only implemented Pos (optimizer state sharding). With Pos+g+p the parameters are also sharded.

os - optimizer state
g - gradient
p - parameters

g requires a reduce of gradients during backward; p requires a broadcast during forward.

mrshenli commented 4 years ago

Is the model passed into ShardedDataParallel already a sharded version of the model where it is the user's responsibility to do this sharding? Looking at deepspeed.initialize, this does seem to be their intention.

This is open for discussion. Initially I wasn't sure whether ShardedDataParallel should sit on top of RPC and RemoteModule. If there are no better options, then this can be the way to go.

I'm not sure I understood why this implementation needs to broadcast all parameters to all ranks? Don't we only need to broadcast the parameters one shard at a time to avoid using a lot of memory?

This is to allow DeepSpeed's optimizer state sharding algorithm to work with other non-DeepSpeed training techniques. According to DeepSpeed's blog post, the optimizer states can be a few times larger than the model state, so I guess there can be use cases where the model fits in one GPU but the optimizer states do not. In such cases, applications can use fully-replicated models + a sharded optimizer. As @blefaudeux mentioned above, when all of DeepSpeed's algorithms are in, we might need to rewrite the sharded optimizer. Actually, thinking about this again, with ShardedDataParallel, regular local optimizers might be sufficient, as the forward pass will pull the updated parameters anyway?

The nn.Sequential here might not fit on GPU/CPU memory of a single host for very large models. Would we be using RemoteModule here? Also, this API is slightly different from deepspeed.initialize since I believe they require the user to pass in a sharded model.

Agree, but I would also like to see if there are any other suggestions. The downside of using the current RemoteModule is that it requires applications to do manual model decomposition. I wonder if there is a better way to automate this process. E.g., I wonder if we can have things like unmaterialized/lazy modules, whose __init__ function just records some configuration instead of creating real tensors.
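
A toy sketch of that lazy-module idea (purely illustrative; the class and method names here are made up): __init__ only records the configuration, and real tensors are allocated later when the owning rank materializes the shard.

import torch.nn as nn

class UnmaterializedLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.config = (in_features, out_features)
        self.linear = None  # no parameters allocated yet

    def materialize(self, device="cpu"):
        in_f, out_f = self.config
        self.linear = nn.Linear(in_f, out_f).to(device)

    def forward(self, x):
        assert self.linear is not None, "call materialize() before forward()"
        return self.linear(x)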

I'm wondering if we should pull out the common building blocks here (bucketing and comm/comp overlap) and have two separate APIs instead. I feel that the DDP Reducer code is already pretty complex and augmenting it to support DeepSpeed might add a lot of different options that we need to think about while working on DDP. If we have two separate frameworks DeepSpeed and DDP that share some common building blocks, it might be easier to iterate on each one separately.

Agree that the current Reducer is already quite complex. I am thinking about creating a ReducerBase where all four components are pluggable. I will finalize this after we know concretely what's needed to support DeepSpeed.

pritamdamania87 commented 4 years ago

@mrshenli @msbaines @blefaudeux Thanks for your comments regarding the optimizer sharding. I understand the reason for the broadcast now since I originally had Pos + g + p in mind, but as Mandeep said we've only implemented Pos so far.

Actually, thinking about this again, with ShardedDataParallel, regular local optimizers might be sufficient, as the forward pass will pull the updated parameters anyway?

Yes, I do feel it would be cleaner for the forward pass to pull the updated parameters instead of the optimizer broadcasting it upfront.

Agree, but I would also like to see if there are any other suggestions. The downside of using the current RemoteModule is that it requires applications to do manual model decomposition. I wonder if there is a better way to automate this process. E.g., I wonder if we can have things like unmaterialized/lazy modules, whose __init__ function just records some configuration instead of creating real tensors.

I'm wondering if DeepSpeed currently supports some sort of automation here? My cursory look seems to suggest they expect the user to partition the model themselves. Given the scope of this project, I feel a good starting point would be to expect the user to manually partition the model.

msbaines commented 4 years ago

ShardedOptimizer works best with a DDP that does a reduce instead of an all_reduce. So it might be simpler to make ShardedOptimizer private, have ShardedDataParallel take a Type[Optimizer] plus optimizer args, and make it responsible for sharding the optimizer. That way the public API is a single class, ShardedDataParallel, that works like DistributedDataParallel except that you tell it how to construct the Optimizer and it requires an nn.Sequential.

class ShardedDataParallel(nn.Module):
    def __init__(self, cpu_model: nn.Sequential, pg: ProcessGroup, optim: Type[Optimizer], **optim_args):

A trivial first implementation could be a simple sub-class of DDP that still does all_reduce and pulls in some of the logic from oss.py. Then add the other enhancements (i.e. Pg + Pp).
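
A hedged sketch of what that trivial first implementation could look like; the round-robin parameter assignment and the optim_step() broadcast below are illustrative only, not fairscale's or DeepSpeed's actual logic, and the broadcast assumes the default process group's global ranks.

from typing import Type

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer

# grads are still allreduced by DDP; only the optimizer state is sharded
class ShardedDataParallel(DDP):
    def __init__(self, module: nn.Sequential, optim_cls: Type[Optimizer], **optim_args):
        super().__init__(module)
        rank, world = dist.get_rank(), dist.get_world_size()
        # naive round-robin partition of parameters across ranks
        self._owners = {p: i % world for i, p in enumerate(module.parameters())}
        owned = [p for p, r in self._owners.items() if r == rank]
        self.optim = optim_cls(owned, **optim_args)

    def optim_step(self):
        self.optim.step()
        # each owner broadcasts its updated params so all replicas stay in sync
        for p, owner in self._owners.items():
            dist.broadcast(p.data, src=owner)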

mrshenli commented 4 years ago

Hey @msbaines, after adding #39272, looks like we might be able to use reduce in the DDP comm hook, and then combine that with the ShardedOptimizer?
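
For concreteness, a hedged sketch of such a hook; owner_of_bucket is a hypothetical shard-to-rank mapping, and the GradBucket accessor names follow the current comm-hook API, which may differ from the prototype discussed here.

import torch.distributed as dist

def owner_of_bucket(bucket):
    # hypothetical mapping from a gradient bucket to the rank that owns its shard
    return bucket.index() % dist.get_world_size()

def reduce_comm_hook(state, bucket):
    # reduce the bucket to its owning rank instead of allreducing it everywhere
    tensor = bucket.buffer()  # flattened gradients in this bucket
    work = dist.reduce(tensor, dst=owner_of_bucket(bucket), async_op=True)
    # the hook must return a Future that resolves to the bucket's tensor
    return work.get_future().then(lambda _: tensor)

# usage: ddp_model.register_comm_hook(state=None, hook=reduce_comm_hook)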

msbaines commented 4 years ago

Hey @msbaines, after adding #39272, looks like we might be able to use reduce in the DDP comm hook, and then combine that with the ShardedOptimizer?

Cool. That's a nice way to do reduce. But how would you combine ShardedOptimizer? The buckets would have to be shard-aligned.

mrshenli commented 4 years ago

The buckets would have to be shard-aligned.

Good point! We at least need to allow custom bucketing to make it work.

blefaudeux commented 4 years ago

The buckets would have to be shard-aligned.

Good point! We at least need to allow custom bucketing to make it work.

Tipped by @pritamdamania: looks like this could be done via the state parameter, no? See

edit: I missed the static grad buckets not being properly aligned, my bad

blefaudeux commented 4 years ago

The buckets would have to be shard-aligned.

Good point! We at least need to allow custom bucketing to make it work.

@mrshenli wrt the custom bucketing, is there some public plan for that? It would be super interesting for the Lightning integration; we would love to get to a proper reduce without having to fork the DDP engine. cc @ananthsub

edit: goes with https://github.com/pytorch/pytorch/issues/37002 I suppose

mrshenli commented 4 years ago

wrt the custom bucketing, is there some public plan for that? It would be super interesting for the Lightning integration; we would love to get to a proper reduce without having to fork the DDP engine.

Yep, we are reviewing a detailed implementation plan for #37002 internally, and we should be able to share it on GitHub in a few weeks.

LifeIsStrange commented 3 years ago

ZeRO-3 has been released! https://news.ycombinator.com/item?id=26447018