prime (previously called ZeroBand) is a framework for efficient, globally distributed training of AI models over the internet.
https://github.com/user-attachments/assets/c034d2a2-400c-4bf8-acd0-c84b6c897d69
**`ElasticDeviceMesh` for Fault Tolerant Training:**

- prime introduces the `ElasticDeviceMesh`, which encapsulates dynamic global process groups for fault-tolerant communication across the internet and local process groups for communication within a node or datacenter.
- The `ElasticDeviceMesh` manages the resizing of the global process groups when nodes join or leave, unlike the standard `DeviceMesh` in torch distributed, which will crash and require a cold restart to resize the process group.
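Deciding when to resize hinges on knowing which peers are still alive. A minimal sketch of the liveness idea (the `HeartbeatTracker` class is illustrative, not prime's actual API; prime's real heartbeat interval and timeout are configured via the `ZERO_BAND_EDM_HEARTBEAT_*` environment variables listed further down):

```python
import time

class HeartbeatTracker:
    """Illustrative liveness tracker: a node is presumed dead if no
    heartbeat arrives within `timeout` seconds; the surviving membership
    is what the global process group would be resized down to."""

    def __init__(self, timeout=10.0):
        self.timeout = timeout
        self.last_seen = {}  # node id -> timestamp of last heartbeat

    def beat(self, node_id, now=None):
        """Record a heartbeat from `node_id`."""
        self.last_seen[node_id] = time.monotonic() if now is None else now

    def alive(self, now=None):
        """Return the sorted ids of nodes whose last heartbeat is recent."""
        now = time.monotonic() if now is None else now
        return sorted(n for n, t in self.last_seen.items()
                      if now - t <= self.timeout)

tracker = HeartbeatTracker(timeout=10.0)
tracker.beat("node-a", now=0.0)
tracker.beat("node-b", now=0.0)
tracker.beat("node-a", now=8.0)      # node-b goes silent after t=0
survivors = tracker.alive(now=12.0)  # node-b exceeded the 10 s timeout
```

In a real deployment each node would send heartbeats on a background thread; here the timestamps are passed explicitly so the behavior is deterministic.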
**Asynchronous distributed checkpointing:**

- To minimize the time that checkpointing blocks training, we first save the checkpoint into `/dev/shm`, which is a RAM-backed filesystem. This operation is much faster, and we can unblock the main training process once the checkpoint has been created in `/dev/shm`.
- The checkpoint is then asynchronously copied out of `/dev/shm` into the checkpoint directory on disk, as well as uploaded to the remote.
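The pattern can be sketched as follows. This is a simplified illustration, not prime's implementation: `save_checkpoint` is a hypothetical helper, and temporary directories stand in for `/dev/shm`, the on-disk checkpoint directory, and the remote upload.

```python
import os
import shutil
import tempfile
import threading

def save_checkpoint(state: bytes, shm_dir: str, disk_dir: str) -> threading.Thread:
    """Write the checkpoint to the fast RAM-backed directory, then copy it
    to slower storage in the background so training is unblocked early."""
    fast_path = os.path.join(shm_dir, "ckpt.bin")
    with open(fast_path, "wb") as f:
        f.write(state)  # fast write; the main process can resume training now

    def _drain() -> None:
        # Background work: copy out of the RAM-backed dir onto disk.
        # A real system would also upload to remote storage here.
        shutil.copy(fast_path, os.path.join(disk_dir, "ckpt.bin"))

    t = threading.Thread(target=_drain, daemon=True)
    t.start()
    return t

# Stand-ins for /dev/shm and the on-disk checkpoint directory.
shm_dir, disk_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
copier = save_checkpoint(b"model-weights", shm_dir, disk_dir)
copier.join()  # in training, joining can be deferred until the next checkpoint
```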
**Custom Int8 All-Reduce Kernels:**

- Quantizing the pseudo-gradients with the built-in torch ops (`quantize_per_tensor`, `scatter_add`, `index`, etc.) was too slow, resulting in underutilisation of our target network bandwidth of 4 Gbps, which is why prime ships custom int8 all-reduce kernels.
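The payload savings from int8 communication can be shown with a toy symmetric per-tensor quantization sketch. This is a pure-Python illustration of the idea only, not prime's kernels, which also have to handle accumulation during the ring reduce:

```python
def quantize_int8(xs):
    """Symmetric per-tensor quantization: fp32 values -> (int8 values, scale).
    Illustrative sketch, not prime's actual kernel."""
    scale = max(abs(v) for v in xs) / 127.0 or 1.0  # `or` guards the all-zero case
    q = [max(-127, min(127, round(v / scale))) for v in xs]
    return q, scale

def dequantize(q, scale):
    """Map int8 values back to approximate fp32 values."""
    return [v * scale for v in q]

grads = [0.5, -1.0, 0.25, 0.0]
q, scale = quantize_int8(grads)
# Payload: 1 byte per int8 value vs 4 bytes per fp32 value -> 4x smaller.
recovered = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(recovered, grads))
```

The rounding error per element is bounded by half the quantization step, which is why (per the text above) communicating pseudo-gradients in int8 was observed not to hurt the loss curves.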
**PyTorch FSDP2 / DTensor ZeRO-3 implementation:**

- We use the `fully_shard` API from PyTorch FSDP2, which wraps the model parameters as `DTensor`s and registers hooks to schedule all-gather and reduce-scatter on the tensors when they are used. FSDP2 also optimizes the collectives by bucketing the parameters into `FSDPParamGroup`s. This lets us execute the collectives on larger tensors, improving the protocol-to-payload ratio and the overlap from pipelining. We employ the same trick for our pseudo-gradients, bucketing them by layer.

A research paper about the framework and our INTELLECT-1 10B experiment is coming soon.
Install `uv`:

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
source $HOME/.cargo/env
```
Set up the environment:

```bash
uv venv
source .venv/bin/activate
uv sync --extra all
uv pip install flash-attn --no-build-isolation
git submodule update --init --recursive
```
Log into Hugging Face: prime uses the tokenizers of the gated models mistralai/Mistral-7B-v0.1 and meta-llama/Llama-2-7b-hf, and pulls the C4:en dataset subset. You must request access to these models and then log into Hugging Face with a read token before starting training:

```bash
huggingface-cli login
```
Verify your setup:

```bash
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
```
To test DiLoCo locally you can use the helper script `scripts/simulate_multi_node_diloco.sh`:

```bash
# Using 4 GPUs (2 simulated nodes, 2 GPUs each)
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml

# Using 2 GPUs (2 simulated nodes, 1 GPU each)
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml
```
Note: Single GPU setups are currently not supported due to an FSDP implementation bug.
Ensure you have at least two GPUs to run the full test suite:

```bash
uv run pytest
```
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `GLOBAL_UNIQUE_ID` | Unique identifier for the worker in the global store. | `None` |
| `GLOBAL_ADDR` | IP address of the global store. | `None` |
| `GLOBAL_PORT` | Port number of the global store. | `None` |
| `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` |
| `GLOBAL_RANK` | Rank of the process in the global process group. | `0` |
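For illustration, a worker might resolve these settings with the defaults from the table above. The `load_global_env` helper and its return keys are hypothetical, not part of prime's codebase:

```python
import os

def load_global_env(env=None):
    """Illustrative helper (not prime's API): read the global-store
    settings, applying the defaults from the table above."""
    env = os.environ if env is None else env
    port = env.get("GLOBAL_PORT")
    return {
        "unique_id": env.get("GLOBAL_UNIQUE_ID"),         # default: None
        "addr": env.get("GLOBAL_ADDR"),                   # default: None
        "port": int(port) if port is not None else None,  # default: None
        "world_size": int(env.get("GLOBAL_WORLD_SIZE", "1")),
        "rank": int(env.get("GLOBAL_RANK", "0")),
    }

defaults = load_global_env({})  # nothing set -> table defaults
node = load_global_env({"GLOBAL_ADDR": "10.0.0.1", "GLOBAL_PORT": "5565",
                        "GLOBAL_WORLD_SIZE": "4", "GLOBAL_RANK": "2"})
```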
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `ZERO_BAND_LOG_LEVEL` | Enable debug mode for the logger. | `False` |
| `ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS` | Number of seconds before global store operations time out. | `300` |
| `ZERO_BAND_GLOBAL_PG_TIMEOUT_SECONDS` | Number of seconds before global process group operations time out. | `600` |
| `ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS` | Number of seconds between polls to the store when waiting for values. | `0.1` |
| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats. | `2` |
| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received. | `10` |
| `ZERO_BAND_LIVE_RECO_PORT` | Port number for the live recovery server. | random |
| `ZERO_BAND_LIVE_RECO_ADDR` | IP address for the live recovery server. | `localhost` |
If you encounter any dataset loading errors at the beginning of training, try setting:

```bash
export HF_HUB_ETAG_TIMEOUT=500
```
Streaming datasets from the Hugging Face Hub can sometimes result in HTTP 443 errors, which will crash the training process. To avoid them, you can pre-download the dataset. Here is an example that downloads all the files in PrimeIntellect/fineweb-edu used by `data_rank` 5 in a training run with a `data_world_size` of 12:
```bash
python3 scripts/subset_data.py --dataset_name PrimeIntellect/fineweb-edu --data_world_size 12 --data_rank 5
```
For information about the script's arguments, run:

```bash
python3 scripts/subset_data.py --help
```