huggingface / datatrove

Freeing data processing from scripting madness by providing a set of platform-agnostic customizable pipeline processing blocks.

SLURM cannot achieve cross-node parallelism #292

Open · ShayDuane opened this issue 4 days ago

ShayDuane commented 4 days ago

I have a SLURM cluster with 50 nodes, each with 96 CPU cores. I want to run a job on the cluster that is split into 192 subtasks, so in theory I should be able to reserve two nodes and run all 192 tasks simultaneously. However, when I pinned the job to node41 and node42 via the nodelist, parallel execution across nodes did not happen: both nodes ran the task with the same rank (see my notes after the logs). Below are the logs:

+ export PYTHONUNBUFFERED=TRUE
+ PYTHONUNBUFFERED=TRUE
+ srun -l launch_pickled_pipeline /shared/home/dsq/test_log/executor.pik
1: 2024-09-25 18:42:05.729 | INFO     | vldata.utils.logging:add_task_logger:58 - Launching pipeline for rank=191 in node42
1: 2024-09-25 18:42:05.730 | INFO     | vldata.utils.logging:log_pipeline:90 - 
1: --- πŸ› οΈ PIPELINE πŸ› 
1: πŸ“– - READER: πŸ“· MINT
1: πŸ’½ - WRITER: πŸ“· MINTWriter
1: 2024-09-25 18:42:05.733 | INFO     | vldata.pipeline.readers.base:read_files_shard:193 - Reading input file part1/copy_191.tar, 1/1
0: 2024-09-25 18:42:06.148 | INFO     | vldata.utils.logging:add_task_logger:58 - Launching pipeline for rank=191 in node41
0: 2024-09-25 18:42:06.149 | INFO     | vldata.utils.logging:log_pipeline:90 - 
0: --- πŸ› οΈ PIPELINE πŸ› 
0: πŸ“– - READER: πŸ“· MINT
0: πŸ’½ - WRITER: πŸ“· MINTWriter
0: 2024-09-25 18:42:06.153 | INFO     | vldata.pipeline.readers.base:read_files_shard:193 - Reading input file part1/copy_191.tar, 1/1
0: 2024-09-25 18:45:58.083 | SUCCESS  | vldata.executor.base:_run_for_rank:98 - Processing done for rank=191
0: 2024-09-25 18:45:58.085 | INFO     | vldata.executor.base:_run_for_rank:104 - 
0: 
0: πŸ“‰πŸ“‰πŸ“‰ Stats: Task 191 πŸ“‰πŸ“‰πŸ“‰
0: 
0: Total Runtime: 3 minutes and 51 seconds
0: 
0: πŸ“– - READER: πŸ“· MINT
0:     Runtime: (99.28%) 3 minutes and 49 seconds [3 seconds and 149.18 millisecondsΒ±8 seconds and 461.30 milliseconds/doc]
0:     Stats: {input_files: 1, doc_len: 2684793 [min=2165, max=113362, 36777.99Β±18429/doc], documents: 72 [72.00/input_file]}
0: πŸ’½ - WRITER: πŸ“· MINTWriter
0:     Runtime: (0.72%) 1 second [22.87 millisecondsΒ±32.31 milliseconds/doc]
0:     Stats: {part1/copy_191.tar: 73, total: 73, doc_len: 2684793 [min=2165, max=113362, 36777.99Β±18429/doc]}
1: 2024-09-25 18:46:01.334 | SUCCESS  | vldata.executor.base:_run_for_rank:98 - Processing done for rank=191
1: 2024-09-25 18:46:01.340 | INFO     | vldata.executor.base:_run_for_rank:104 - 
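
If I read the logs correctly, the rank of each launched task is presumably derived from SLURM_ARRAY_TASK_ID, which is identical on every node of a single array element; with --nodes=2, srun starts one copy per node (the 0:/1: prefixes above), so both copies resolve to the same rank. Below is a minimal sketch of that assumption (resolve_rank is hypothetical, not the actual vldata/datatrove code):

import os

# Hypothetical sketch, not the real executor code: how I assume the launched
# step picks its rank. SLURM_ARRAY_TASK_ID is the same on every node of one
# array element, so with --nodes=2 both srun copies of element 191 compute
# rank=191 and read the same shard (part1/copy_191.tar above).
def resolve_rank() -> int:
    return int(os.environ["SLURM_ARRAY_TASK_ID"])

# SLURM_PROCID and SLURMD_NODENAME do differ between the two copies, but they
# do not seem to be used to split the work, so node41 and node42 duplicate it.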

Below is my code:

import argparse

# (imports of SlurmPipelineExecutor, MINTReader and MINTWriter are omitted in this snippet)

parser = argparse.ArgumentParser(description="Read and Write example")
parser.add_argument("--input_folder", default="**", help="Input folder path")
parser.add_argument("--base_output_folder", default="**", help="Base output folder path")
parser.add_argument('--tasks', default=192, type=int,
                    help='total number of tasks to run the pipeline on (default: 192)')
parser.add_argument('--workers', default=-1, type=int,
                    help='how many tasks to run simultaneously. (default is -1 for no limit aka tasks)')
parser.add_argument('--limit', default=-1, type=int,
                    help='Number of files to process')
parser.add_argument('--logging_dir', default="**", type=str,
                    help='Path to the logging directory')
# parser.add_argument('--local_tasks', default=-1, type=int,
#                     help='how many of the total tasks should be run on this node/machine. -1 for all')
# parser.add_argument('--local_rank_offset', default=0, type=int,
#                     help='the rank of the first task to run on this machine.')
parser.add_argument('--job_name', default='**', type=str,
                    help='Name of the job')
parser.add_argument('--condaenv', default='vldata', type=str,
                    help='Name of the conda environment')
parser.add_argument('--slurm_logs_folder', default='**', type=str,
                    help='Path to the slurm logs folder')
parser.add_argument(
    '--nodelist', 
    type=str, 
    default='node41,node42', 
    help='Comma-separated list of nodes (default: node41,node42)'
)
parser.add_argument('--nodes', default=2, type=int,
                    help='Number of nodes to use')
parser.add_argument('--time', default='01:00:00', type=str,
                    help='Time limit for the job')

parser.add_argument(
    '--exclude', 
    type=str, 
    help='List of nodes to exclude'
)

if __name__ == '__main__':
    args = parser.parse_args()

    sbatch_args = {}

    if args.nodelist:
        sbatch_args["nodelist"] = args.nodelist

    if args.exclude:
        sbatch_args["exclude"] = args.exclude
    if args.nodes:
        sbatch_args["nodes"] = args.nodes
    pipeline = [
        MINTReader(data_folder=args.input_folder, glob_pattern="*.tar", limit=args.limit),
        MINTWriter(output_folder=args.base_output_folder)
    ]

    executor = SlurmPipelineExecutor(
        pipeline=pipeline,
        tasks=args.tasks,
        workers=args.workers,
        logging_dir=args.logging_dir,
        partition='cpu',
        sbatch_args=sbatch_args,
        condaenv=args.condaenv,
        time=args.time,
        job_name=args.job_name,
        slurm_logs_folder=args.slurm_logs_folder,
    )

    print(executor.run())
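
For reference, this is roughly how I understand the Python configuration above maps onto the sbatch script that gets generated (my reading of the two files in this issue, not the exact datatrove code):

# tasks=192                    -> #SBATCH --array=0-191
# sbatch_args["nodelist"]      -> #SBATCH --nodelist=node41,node42
# sbatch_args["nodes"]         -> #SBATCH --nodes=2
# partition='cpu'              -> #SBATCH --partition=cpu
# time='01:00:00'              -> #SBATCH --time=01:00:00
# slurm_logs_folder            -> #SBATCH --output / --error paths
# the pickled executor         -> srun -l launch_pickled_pipeline ****/executor.pik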

Below is my sbatch script:

#!/bin/bash
#SBATCH --cpus-per-task=1
#SBATCH --mem-per-cpu=2G
#SBATCH --partition=cpu
#SBATCH --job-name=****
#SBATCH --time=01:00:00
#SBATCH --output=***slurm_logs/%A_%a.out
#SBATCH --error=***/slurm_logs/%A_%a.out
#SBATCH --array=0-191
#SBATCH --nodelist=node41,node42
#SBATCH --nodes=2
#SBATCH --requeue
#SBATCH --qos=normal
echo "Starting data processing job ****"
conda init bash
conda activate vldata
source ~/.bashrc
set -xe
export PYTHONUNBUFFERED=TRUE
srun  -l launch_pickled_pipeline ****/executor.pik
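
For what it's worth, the workaround I plan to try (untested, and based on my assumption that --nodes=2 makes every array element span both nodes and repeat a single rank) is to keep each array element a single-node, single-task job and steer placement with the existing --exclude argument instead of --nodelist/--nodes:

# Untested sketch: no "nodelist" and no "nodes" entry in sbatch_args, so each
# array element only requests 1 task / 1 CPU on 1 node, and SLURM is free to
# place the 192 elements on node41 and node42 concurrently. The other nodes
# would be kept away via --exclude (whatever value is passed is illustrative).
sbatch_args = {}
if args.exclude:
    sbatch_args["exclude"] = args.exclude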