In this project, we propose to use pipeline parallelism for large-scale language modeling.
See cluster/README.md to set up a cluster for development and testing. After that, clone the repo to the NFS directory `~/efs` shared by all nodes:

```bash
cd ~/efs
git clone https://github.com/zhuohan123/model-parallel-speed-test.git
```
See the `MODEL_CONFIGS` dictionary in `transformer_models.py` for the list of models we test on.
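For reference, an entry in a `MODEL_CONFIGS`-style table might look roughly like the sketch below. The field layout and the `gpt2-1.5b` entry are assumptions for illustration only; the actual schema in `transformer_models.py` may differ (only the `test` name is taken from the example run below).

```python
# Hypothetical sketch of a MODEL_CONFIGS-style table; the real dictionary in
# transformer_models.py may use different keys and values.
MODEL_CONFIGS_EXAMPLE = {
    # name: (num_layers, hidden_size, num_attention_heads, seq_len)
    "test":      (2,   128,   2,  128),
    "gpt2-1.5b": (48, 1600,  25, 1024),
}

def lookup(model_name):
    """Return the (layers, hidden, heads, seq_len) tuple for a model name."""
    return MODEL_CONFIGS_EXAMPLE[model_name]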
Pipelining along the sequence-length dimension on all GPUs in a node:
```bash
# Arguments to ./mpirun_terapipe.sh (in order): number of nodes, number of GPUs per node,
# model parallel size, pipeline parallel size, model name, number of slices, number of steps
N_NODES=1                 # Number of nodes in the cluster
N_GPUS=1                  # Number of GPUs per node
MODEL_PARALLEL_SIZE=1     # Number of devices in a single model-parallel (parallel matmul) group
PIPELINE_PARALLEL_SIZE=1  # Number of stages for pipelining
# Note that $N_NODES * $N_GPUS == $MODEL_PARALLEL_SIZE * $PIPELINE_PARALLEL_SIZE
MODEL=test                # Name of the model to test (see MODEL_CONFIGS)
N_SLICES=8                # Number of input slices (currently the input is sliced uniformly)
N_STEPS=10                # Number of testing steps to run
EXTRA_ARGS="--mixed-precision"
./mpirun_terapipe.sh $N_NODES $N_GPUS $MODEL_PARALLEL_SIZE $PIPELINE_PARALLEL_SIZE $MODEL $N_SLICES $N_STEPS $EXTRA_ARGS
```
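To make the size constraint concrete: with 2 nodes of 8 GPUs each, the 16 ranks can be split into model-parallel groups of size 8 and pipeline-parallel groups of size 2. The sketch below shows one common layout (consecutive ranks form a model-parallel group); this is an assumption for illustration and not necessarily how `mpirun_terapipe.sh` assigns ranks.

```python
# Hypothetical rank layout: consecutive ranks form a model-parallel group,
# and ranks at the same offset across groups form a pipeline.
# The actual assignment in this repo may differ.
def partition_ranks(n_nodes, n_gpus, mp_size, pp_size):
    world_size = n_nodes * n_gpus
    assert world_size == mp_size * pp_size, \
        "N_NODES * N_GPUS must equal MODEL_PARALLEL_SIZE * PIPELINE_PARALLEL_SIZE"
    model_parallel_groups = [list(range(i, i + mp_size))
                             for i in range(0, world_size, mp_size)]
    pipeline_groups = [list(range(i, world_size, mp_size))
                       for i in range(mp_size)]
    return model_parallel_groups, pipeline_groups

# Example: 2 nodes x 8 GPUs = 16 ranks = MP 8 x PP 2.
mp_groups, pp_groups = partition_ranks(2, 8, 8, 2)
print(mp_groups)  # [[0, 1, ..., 7], [8, 9, ..., 15]]
print(pp_groups)  # [[0, 8], [1, 9], ..., [7, 15]]
```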
Edit `auto_latency_benchmark.sh` and add your model for computation-latency evaluation. Run `./auto_latency_benchmark.sh` on one p3.16xlarge machine. Outputs are written to `performance_model_data`.
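For intuition, computation-latency benchmarking amounts to timing the forward pass of a model for different slice (sub-sequence) lengths. The sketch below uses a stock `torch.nn.TransformerEncoderLayer` rather than the models in `transformer_models.py`, so the layer, shapes, and trial counts are illustrative assumptions only.

```python
import time
import torch

# Illustrative only: the repo benchmarks its own models from transformer_models.py;
# here we time a stock TransformerEncoderLayer to show the idea of measuring
# forward latency as a function of slice length.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = torch.nn.TransformerEncoderLayer(d_model=1024, nhead=16).to(device).eval()

def forward_latency(slice_len, batch_size=1, n_trials=10):
    x = torch.randn(slice_len, batch_size, 1024, device=device)
    with torch.no_grad():
        for _ in range(3):  # warm-up iterations
            layer(x)
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_trials):
            layer(x)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.time() - start) / n_trials

for s in (32, 64, 128, 256):
    print(s, forward_latency(s))
```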
Edit `p2p_comm_latency.py` and add your model for communication-latency evaluation. Run `./p2p_comm_latency.sh` on two p3.16xlarge machines. Outputs are written to `performance_model_data`.
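As an illustration of what a point-to-point communication benchmark measures, the sketch below times a ping-pong between two ranks with `torch.distributed`. The backend choice, environment-variable rendezvous, and message sizes are assumptions; the repo's `p2p_comm_latency.py` may be organized differently.

```python
import time
import torch
import torch.distributed as dist

# Illustrative sketch of a point-to-point latency measurement between two ranks.
# Assumes the launcher sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
dist.init_process_group(backend="gloo")  # "nccl" when measuring GPU-to-GPU links
rank = dist.get_rank()

def p2p_latency(numel, n_trials=20):
    buf = torch.zeros(numel)
    start = time.time()
    for _ in range(n_trials):  # ping-pong between rank 0 and rank 1
        if rank == 0:
            dist.send(buf, dst=1)
            dist.recv(buf, src=1)
        else:
            dist.recv(buf, src=0)
            dist.send(buf, dst=0)
    return (time.time() - start) / n_trials / 2  # rough one-way latency estimate

for n in (1 << 10, 1 << 16, 1 << 20):
    lat = p2p_latency(n)
    if rank == 0:
        print(f"{n} floats: {lat * 1e3:.3f} ms")
```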
Edit and run `latency_model.py` to generate the optimal slices with dynamic programming (DP). Results are saved in `dp_results.json`.
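As a rough idea of the DP, the sketch below assumes a simplified cost model: with `K` pipeline stages and contiguous slices of lengths `s_1, ..., s_n`, the total time is approximately `sum_i t(s_i) + (K - 1) * max_i t(s_i)`, where the measured per-slice latency `t` is non-decreasing in slice length. The exact formulation in `latency_model.py` may differ.

```python
# Minimal sketch of slice selection by dynamic programming. This is NOT the
# exact formulation in latency_model.py; it assumes the simplified cost model
# described above and a latency function t that grows with slice length.

def optimal_slices(seq_len, n_stages, t):
    """t: function mapping slice length -> measured latency."""
    best_cost, best_split = float("inf"), None
    for cap in range(1, seq_len + 1):  # candidate maximum slice length
        # dp[i] = minimal sum of t over a partition of the first i tokens
        # into slices of length <= cap; parent[i] remembers the last cut.
        dp = [float("inf")] * (seq_len + 1)
        parent = [0] * (seq_len + 1)
        dp[0] = 0.0
        for i in range(1, seq_len + 1):
            for s in range(1, min(cap, i) + 1):
                cand = dp[i - s] + t(s)
                if cand < dp[i]:
                    dp[i], parent[i] = cand, i - s
        # Recover the slice lengths and evaluate the full cost model.
        slices, i = [], seq_len
        while i > 0:
            slices.append(i - parent[i])
            i = parent[i]
        cost = dp[seq_len] + (n_stages - 1) * max(t(s) for s in slices)
        if cost < best_cost:
            best_cost, best_split = cost, list(reversed(slices))
    return best_cost, best_split

# Toy example with a made-up latency function (fixed overhead + linear term).
cost, slices = optimal_slices(seq_len=64, n_stages=4, t=lambda s: 0.5 + 0.1 * s)
print(cost, slices)
```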
Edit and run `auto_mpirun_dp_slices_evaluation.sh`. Results are saved under `dp_evaluation_results`.
Get the IPs of all the worker nodes in the cluster:

```bash
python scripts/get_worker_ips.py
```
Load `$MY_IPADDR`, `$OTHERS_IPADDR`, and `$ALL_IPADDR` as environment variables:

```bash
source scripts/load_cluster_env.sh
```
Run the same command on all nodes (useful for killing processes and checking node status):

```bash
scripts/fornode pkill python
scripts/fornode nvidia-smi
```
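The pattern behind a `fornode`-style helper is simply running the same command over SSH on every node. The Python sketch below is hypothetical (the real `scripts/fornode` is a shell script and may behave differently); it assumes `$ALL_IPADDR` holds a space-separated list of node IPs and that passwordless SSH between nodes is configured.

```python
import os
import subprocess
import sys

# Hypothetical illustration of the "run everywhere" pattern behind scripts/fornode.
# Assumes load_cluster_env.sh exported ALL_IPADDR as a space-separated IP list
# and that passwordless SSH between nodes is set up.
def run_on_all_nodes(command):
    ips = os.environ["ALL_IPADDR"].split()
    for ip in ips:
        print(f"=== {ip} ===")
        subprocess.run(["ssh", "-o", "StrictHostKeyChecking=no", ip, command])

if __name__ == "__main__":
    run_on_all_nodes(" ".join(sys.argv[1:]))  # e.g. python run_all.py nvidia-smi
```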