pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

Fail-safe and partial redundancy for HSDP on unreliable compute #561

Open evkogs opened 2 months ago

evkogs commented 2 months ago

I'd like to propose a feature that implements fail-safe mechanisms and partial redundancy in FSDP2 (more precisely HSDP than plain FSDP) to allow more robust training on unreliable compute resources, such as cloud spot instances. The main goal is to make training more resilient to node failures, GPU issues, and other interruptions.

Key points:

  1. Implement an abstraction over DDP and FSDP with configurable parameters for redundancy.
  2. Allow for partial redundancy, similar to RAID5 or RAID6 concepts, where full redundancy would be equivalent to DDP and zero redundancy would be equivalent to FSDP full-shard or ZeRO-3.
  3. Mitigate node failures and individual GPU failures by storing additional fractions (e.g., 1/8 or 1/4) of other nodes' optimizer states on each node.
  4. Trade extra memory usage and all-reduce overhead (estimated at 10-20%) for increased training resilience.
  5. Implement automatic downscaling with resharding and upscaling with automatic sharding, with a configurable overlapping sharding parameter (0.0 to 1.0).

Use case examples:

  1. Training on cloud spot instances that may be terminated mid-training.
  2. Training giant models on 99.9%-reliable hardware, protecting against network adapter failures, power outages, etc.
  3. Enabling cross-region training of colossal models on spot instances or multi-region clusters.
  4. Supporting distributed training methods like DisTrO (https://github.com/NousResearch/DisTrO) that allow training over the internet with much lower network throughput requirements than the traditional all-reduce approach.

This feature would greatly enhance the flexibility and reliability of large-scale distributed training, especially in scenarios where compute resources are not guaranteed to be stable throughout the entire training process.

A key aspect of this implementation would be an overlapping factor, ranging from 0.0 to 1.0, which determines the degree of redundancy. For example, with 64 GPUs across 8 nodes:

The system would need to integrate downscaling with resharding and automatic restoring, as well as upscaling with automatic sharding, all governed by this overlapping factor (probably orchestrated with torchx on Kubernetes, for example).
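To make the overlapping factor concrete, here is a minimal sketch under stated assumptions, not a PyTorch API. It assumes (hypothetically) that a factor f maps to k = round(f * (N - 1)) extra shard copies per node, so 0.0 gives FSDP-style sharding and 1.0 gives DDP-style full replication, and that copies are placed on a ring (node i also keeps shards i+1 through i+k). This is plain replication, closer to RAID 1 mirroring than to the parity schemes of RAID 5/6. All function names are illustrative.

```python
import math


def replicas_for_overlap(overlap: float, num_nodes: int) -> int:
    """Hypothetical mapping from an overlap factor in [0.0, 1.0] to extra copies.

    0.0 -> 0 extra copies (FSDP / ZeRO-3-like), 1.0 -> N-1 extra copies (DDP-like).
    """
    if not 0.0 <= overlap <= 1.0:
        raise ValueError("overlap must be in [0.0, 1.0]")
    # floor(x + 0.5) rounds half up, unlike Python's built-in round().
    return math.floor(overlap * (num_nodes - 1) + 0.5)


def ring_placement(num_nodes: int, extra: int) -> dict[int, list[int]]:
    """Node i holds its own shard i plus copies of shards i+1 .. i+extra (mod N)."""
    return {
        node: [(node + j) % num_nodes for j in range(extra + 1)]
        for node in range(num_nodes)
    }


def survives(placement: dict[int, list[int]], failed: set[int], num_shards: int) -> bool:
    """True if every shard still has at least one copy on a healthy node."""
    held = {s for node, shards in placement.items() if node not in failed for s in shards}
    return held == set(range(num_shards))


if __name__ == "__main__":
    n = 8                                  # e.g. 8 nodes, one shard per node
    k = replicas_for_overlap(0.25, n)      # 2 extra copies per shard
    placement = ring_placement(n, k)
    print(k, placement[0])                 # 2 [0, 1, 2]
    print(survives(placement, {3, 4}, n))     # True
    print(survives(placement, {3, 4, 5}, n))  # False: shard 5 lived only on nodes 3, 4, 5
```

Under these assumptions, with 8 nodes and f = 0.25 every shard lives on 3 nodes, so any 2 node failures are survivable, while 3 adjacent failures can lose a shard. Each node then stores (k + 1)/N = 3/8 of the sharded state instead of 1/8.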

I'd be happy to discuss this further and provide more details if needed! Looking forward to your thoughts on this proposal!

tianyu-l commented 2 months ago

@awgu @wconstab @fegin

evkogs commented 2 months ago

I see it mainly as a complementary addition to the existing torch.distributed.elastic functionality.

Also, given the many ways to launch a training job, the core functionality would be restoring all model weights, activations, and optimizer states onto a smaller number of workers (scale-down).
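A scale-down restore could be planned roughly as follows. This is a hypothetical sketch, assuming the ring placement of redundant shard copies described earlier in the proposal's terms (node i keeps shard i plus the next k shards); `ring_placement` and `plan_scale_down` are illustrative names, not part of torchtitan or torch.distributed. Survivors get contiguous new ranks, and each shard is read from some surviving node that already holds a copy.

```python
def ring_placement(num_nodes: int, extra: int) -> dict[int, list[int]]:
    # Hypothetical placement: node i keeps shard i plus copies of the next `extra` shards.
    return {n: [(n + j) % num_nodes for j in range(extra + 1)] for n in range(num_nodes)}


def plan_scale_down(
    placement: dict[int, list[int]], failed: set[int], num_shards: int
) -> tuple[dict[int, int], dict[int, list[int]]]:
    """Return (new_rank_of_node, shards_loaded_by_node) for the surviving nodes."""
    survivors = sorted(n for n in placement if n not in failed)
    # Surviving nodes get contiguous ranks 0 .. len(survivors) - 1.
    new_rank = {node: rank for rank, node in enumerate(survivors)}
    owned: dict[int, list[int]] = {n: [] for n in survivors}
    for shard in range(num_shards):
        holders = [n for n in survivors if shard in placement[n]]
        if not holders:
            # More failures than the overlap covers: fall back to a checkpoint.
            raise RuntimeError(f"shard {shard} has no surviving copy")
        # Greedy balancing: the holder with the fewest assigned shards takes it.
        owner = min(holders, key=lambda n: (len(owned[n]), n))
        owned[owner].append(shard)
    return new_rank, owned


if __name__ == "__main__":
    placement = ring_placement(8, 1)   # 8 nodes, one extra copy per shard
    ranks, owned = plan_scale_down(placement, failed={2}, num_shards=8)
    print(ranks)  # {0: 0, 1: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6}
    print(owned)  # {0: [0], 1: [1, 2], 3: [3], 4: [4], 5: [5], 6: [6], 7: [7]}
```

In this example the node that absorbed the lost shard ends up with twice the load; a real implementation would presumably re-split the loaded state evenly across the new world size, and this plan only decides where each shard is read from.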

When a specific launcher is used, e.g. torchrun or torchx with the Kubernetes scheduler, there's also the option to fully manage the cluster and replace workers (both scale-up and scale-down).

Also, for clusters of thousands of GPUs, the overhead won't be significant: with 64-128 or more nodes, an overlapping factor of 2.5-5% might be enough to guarantee resilience to outages, which is a small cost.
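The 2.5-5% figure can be turned into rough numbers with a small calculation. This sketch assumes, hypothetically, that a factor f means k = round(f * (N - 1)) extra copies of each shard on k + 1 distinct nodes (so any k node failures are tolerated) and that each node then stores (k + 1)/N of the full sharded state; the function names are illustrative only.

```python
import math


def extra_copies(overlap: float, num_nodes: int) -> int:
    # Hypothetical mapping: 0.0 -> no extra copies (FSDP), 1.0 -> N-1 copies (DDP).
    # floor(x + 0.5) rounds half up, unlike Python's round().
    return math.floor(overlap * (num_nodes - 1) + 0.5)


def per_node_fraction(overlap: float, num_nodes: int) -> float:
    # Share of the full sharded state (e.g. optimizer state) stored on each node.
    return (extra_copies(overlap, num_nodes) + 1) / num_nodes


if __name__ == "__main__":
    for n in (64, 128):
        for f in (0.025, 0.05):
            k = extra_copies(f, n)
            print(
                f"N={n:3d} f={f:.3f}: {k} extra copies, "
                f"tolerates any {k} node failures, "
                f"{per_node_fraction(f, n):.2%} of full state per node"
            )
```

Under this reading, the four cases give 2 to 6 extra copies, and per-node state grows from 1/N to (k + 1)/N: 3x to 7x a plain FSDP shard, while staying at roughly 3-6% of the full state.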

jiamings commented 2 months ago

This is actually a great idea -- since ECC errors are quite common in HBM, this could spare us from restarting the entire job whenever we hit a single ECC error. I'm not sure how well it would work with distributed checkpointing, though.

wconstab commented 2 months ago

Thanks for this proposal @evkogs! We would need to get more specific about a design to say for sure, but I think there are largely 2 issues that need to be addressed before this could be feasible.

1) How can we drop some members out of a communicator and add new ones when the scheduler replaces them (e.g. PyTorch ProcessGroupNCCL + nccl communicator)? Today, the only way to do this is to tear down the 'world' and create a new 'world'. This can be expensive, and requires coordination.

2) What is the right abstraction boundary between pytorch and the scheduler? We probably do not want to build all of this logic into pytorch as some of it ties into the job scheduling layer. Can we come up with a clear abstraction and propose which behaviors pytorch should implement and which ones should be provided by the scheduler itself?

evkogs commented 2 months ago

Thanks @wconstab!

1) How can we drop some members out of a communicator and add new ones when the scheduler replaces them (e.g. PyTorch ProcessGroupNCCL + nccl communicator)? Today, the only way to do this is to tear down the 'world' and create a new 'world'. This can be expensive and requires coordination.

Well, I don't think that's an issue, since it would be an infrequent event, happening at most 2-3 times even with many nodes in an unreliable setup. So I think the current approach would be absolutely fine for real-world cases. From the torch.distributed.elastic docs:

Membership Changes
Node departure (scale-down): The agent is notified of the departure, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped, a new WorkerGroup is formed, and all workers are started with a new RANK and WORLD_SIZE.
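The restart semantics quoted from the docs can be illustrated with a simplified model of re-rendezvous. The dictionary keys mirror environment variables that torchrun exports (RANK, LOCAL_RANK, GROUP_RANK, LOCAL_WORLD_SIZE, WORLD_SIZE), but the ordering logic here (sorting node ids) is an assumption, since the real rendezvous decides ordering itself, and real environment values are strings rather than ints. The point it shows: after a departure, surviving workers can receive a different RANK, so any shard-to-rank mapping has to be recomputed rather than reused.

```python
from collections.abc import Iterable


def assign_ranks(node_ids: Iterable[int], nproc_per_node: int) -> dict[tuple[int, int], dict[str, int]]:
    """Simplified re-rendezvous: every (node, local process) gets fresh rank info."""
    nodes = sorted(node_ids)  # assumption: deterministic order by node id
    world_size = len(nodes) * nproc_per_node
    envs: dict[tuple[int, int], dict[str, int]] = {}
    for group_rank, node in enumerate(nodes):
        for local_rank in range(nproc_per_node):
            envs[(node, local_rank)] = {
                "GROUP_RANK": group_rank,
                "LOCAL_RANK": local_rank,
                "RANK": group_rank * nproc_per_node + local_rank,
                "LOCAL_WORLD_SIZE": nproc_per_node,
                "WORLD_SIZE": world_size,
            }
    return envs


if __name__ == "__main__":
    before = assign_ranks(range(8), nproc_per_node=8)  # 8 nodes x 8 GPUs
    after = assign_ranks([n for n in range(8) if n != 3], nproc_per_node=8)  # node 3 departed
    print(before[(5, 0)]["RANK"], before[(5, 0)]["WORLD_SIZE"])  # 40 64
    print(after[(5, 0)]["RANK"], after[(5, 0)]["WORLD_SIZE"])    # 32 56
```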

2) What is the right abstraction boundary between pytorch and the scheduler? We probably do not want to build all of this logic into pytorch as some of it ties into the job scheduling layer. Can we come up with a clear abstraction and propose which behaviors pytorch should implement and which ones should be provided by the scheduler itself?

That's a very good question! I think there's a place for a unified approach that combines all the existing ones. I was also curious to look into the PyTorch H2 2024 plans, and saw that they include integrating up to 5D model parallelism (whatever that means), so it might get even trickier soon. I feel that if we keep growing the number of abstractions, it won't end well.

d4l3k commented 2 days ago

@evkogs, @jiamings just for visibility, I've been working on getting something like this working for FSDP/torchtitan. It's still early and isn't quite compatible yet, but a lot of the pieces are in place.

https://github.com/pytorch-labs/torchft

GrigoryEvko commented 2 days ago

Thanks a lot! Yep, looks like something that would do the task!