wandb / wandb

The AI developer platform. Use Weights & Biases to train and fine-tune models, and manage models from experimentation to production.
https://wandb.ai

[Q] Using WandB Sweep + SLURM + Pytorch Lightning DDP + Multiple GPUs #5695

Open OFSkean opened 1 year ago

OFSkean commented 1 year ago

I'm trying to register SLURM nodes as agents for sweeps. I'm using Pytorch Lightning with DDP and multiple GPUs. Following the recommendations from Pytorch Lightning (here), my SLURM sbatch script is something like below.

#!/bin/bash

#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=8
#SBATCH --ntasks=4
#SBATCH --ntasks-per-node=4
#SBATCH --nodes=1

sweep_id="***********"
api_key="***********"

# Method 1: 
# Creates 4 processes that are bound together to run 1 job on 4 gpus,
# but can't receive sweep configs from wandb, so configs must be changed manually (aka bad)

srun python3 myprogram.py ${static_notsweep_args}

# Method 2:
# My attempt at registering this node as a wandb agent with access to 4 gpus
# Instead, this creates 4 agents, each with access to 1 gpu

srun  wandb login $api_key
srun  wandb agent $sweep_id

My jobs require multiple gpus (4 in this example) to run. Note that Pytorch Lightning requires running commands with srun and setting ntasks equal to the number of GPUs when using multiple GPUs. Combined, these cause each srun line to be run ntasks times, essentially creating ntasks processes running in parallel.

The problem I'm having with sweeping is that since wandb agent $sweep_id gets run ntasks times, it creates ntasks agents, each running a separate configuration from the sweep. Furthermore, this prevents Lightning DDP from binding them together, which restricts each agent to only 1 GPU. This would actually be fine if 1 GPU per agent were enough, but I need all the gpus.

There are some potential solutions I thought of, but they have their downsides:

  1. Setting ntasks=1 and using ddp_spawn for the Lightning trainer strategy. This way wandb agent only gets called once, and the appropriate number of processes is spawned for training. The downside is that ddp_spawn is widely discouraged for performance reasons.

  2. Switching from the CLI to the Python wandb API and registering the agent from inside myprogram.py (a rough sketch of what I mean follows this list). The Python API seems more flexible, but I haven't used it, so I don't actually know whether this would work. The downside is that I'd prefer to stick to the wandb CLI.
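For reference, this is roughly what I mean by option 2 (untested sketch; main() stands in for whatever myprogram.py already does, and the sweep id is a placeholder):

import wandb

def main():
    # When launched by an agent, wandb.init() picks up this trial's hyperparameters
    run = wandb.init()
    config = run.config  # sweep-chosen hyperparameters
    # ... build the Lightning trainer and train as usual ...

if __name__ == "__main__":
    sweep_id = "entity/project/xxxxxxxx"  # placeholder
    # Register this process as a sweep agent and run one trial in-process
    wandb.agent(sweep_id, function=main, count=1)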

The behavior I'd like to see is wandb agent $sweep_id running the same program + hyperparameters in parallel ntasks times, so that Lightning DDP can bind the processes together and use all the gpus. I'm wondering if there is a way to accomplish this with the wandb CLI, for example something like wandb agent <agent_id>, so that the multiple calls to wandb agent get linked to the same agent.

nate-wandb commented 1 year ago

Hi @OFSkean, my first instinct would be to only use a single task and use Trainer(accelerator="gpu", devices=4, strategy="ddp") in your train.py to spin up the parallel GPU processes. You will need to put wandb.init() in an if block like this:

import wandb

if __name__ == "__main__":
    # Get args
    args = parse_args()

    if args.local_rank == 0:  # only on main process
        # Initialize wandb run
        run = wandb.init(
            entity=args.entity,
            project=args.project,
        )
        # Train model with DDP
        train(args, run)
    else:
        train(args)

That way, you only create a run in the rank 0 process.

I don't think we have a complete example of Slurm + Sweep + DDP, but I'm happy to work through this if that doesn't work.
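If local_rank isn't already wired into your args, here's a minimal sketch (my assumption, not something wandb provides) of detecting the main process from environment variables that SLURM or torchrun/Lightning set:

import os

def is_main_process():
    # SLURM_PROCID is the global task id under srun; LOCAL_RANK is set by
    # torchrun and by Lightning's DDP launcher. On a single node they agree.
    rank = int(os.environ.get("SLURM_PROCID", os.environ.get("LOCAL_RANK", "0")))
    return rank == 0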

OFSkean commented 1 year ago

Hi @nate-wandb,

I tried setting ntasks=1, and while that solves the problem of wandb agent being called multiple times, it causes issues with Pytorch Lightning. Per this doc, when using SLURM with Lightning, ntasks must equal the number of devices.

There are two parameters in the SLURM submission script that determine how many processes will run your training, the #SBATCH --nodes=X setting and the #SBATCH --ntasks-per-node=Y setting. The numbers there need to match what is configured in your Trainer in the code: Trainer(num_nodes=X, devices=Y). If you change the numbers, update them in BOTH places.

I don't know enough about Lightning to know why that's required when using SLURM, but the upshot is that I have to keep ntasks=4.
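Just to make the constraint concrete, here is the pairing for my 1-node / 4-GPU case (nothing sweep-specific here, just the Lightning requirement from that doc):

# SLURM side:  #SBATCH --nodes=1   #SBATCH --ntasks-per-node=4   #SBATCH --gres=gpu:4
# Lightning side must use the same numbers:
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,      # = --ntasks-per-node (one task per GPU)
    num_nodes=1,    # = --nodes
    strategy="ddp",
)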

OFSkean commented 1 year ago

Ok I figured out a way to do this. It's really ugly but it works.

Sweep Configuration

project: slurm_test
name: my_sweep 
program: main.py

command:
  - ${env}
  - echo    # this line is different from the usual command
  - python3
  - ${program}
  - ${args}

SLURM sbatch script

#!/bin/bash

#NUMBER OF AGENTS TO REGISTER AS WANDB AGENTS
#SHOULD BE --array=1-X, where X is the estimated number of runs
#SBATCH --array=1-4    #e.g. 1-4 will create agents labeled 1,2,3,4

#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=8
#SBATCH --ntasks=4 # must equal number of gpus, as required by Lightning

# .... other SLURM configuration like partition, time, etc ....
# .... module purge and start conda environment if needed ....

# SET SWEEP_ID HERE. Note sweep must already be created on wandb before submitting job
SWEEP_ID="**************************************"
API_KEY="******************************************"

# LOGIN IN ALL TASKS
srun wandb login $API_KEY

# adapted from https://stackoverflow.com/questions/11027679/capture-stdout-and-stderr-into-different-variables
# RUN WANDB AGENT IN ONE TASK
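# The agent's stderr (which contains the new run id) streams into the pipe first and is
# read into SWEEP_DETAILS; printf then wraps the agent's stdout (the echoed
# "python3 main.py ..." command) in NUL bytes, and that is read into SWEEP_COMMAND.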
{
    IFS=$'\n' read -r -d '' SWEEP_DETAILS; RUN_ID=$(echo $SWEEP_DETAILS | sed -e "s/.*\[\([^]]*\)\].*/\1/g" -e "s/[\'\']//g")
    IFS=$'\n' read -r -d '' SWEEP_COMMAND;
} < <((printf '\0%s\0' "$(srun --ntasks=1 wandb agent --count 1 $SWEEP_ID)" 1>&2) 2>&1)

SWEEP_COMMAND="${SWEEP_COMMAND} --wandb_resume_version ${RUN_ID}"

# WAIT FOR ALL TASKS TO CATCH UP
wait

# RUN SWEEP COMMAND IN ALL TASKS
srun  $SWEEP_COMMAND

Python code using Pytorch Lightning DDP

... whatever code before

wandb_logger = pl.loggers.WandbLogger(name=args.experiment_name, version=args.wandb_resume_version, resume="must")
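# resume="must" re-attaches to the run the agent already created, identified by the
# run id passed in via --wandb_resume_version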

# once again, pytorch lightning requires setting:
# devices=number of allocated slurm gpus = number of slurm tasks 
# num_nodes= number of slurm nodes

trainer = pl.Trainer(devices=args.devices, num_nodes=args.num_nodes, 
                        accelerator='gpu', strategy='ddp_find_unused_parameters_false', 
                        logger=wandb_logger)

... whatever training code after .....
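The argument parsing isn't shown above; this is roughly what I'm assuming (flag names are just the ones I picked for this example):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--experiment_name", default="sweep_run")
parser.add_argument("--wandb_resume_version", default=None)  # run id captured in the sbatch script
parser.add_argument("--devices", type=int, default=4)        # = number of SLURM tasks / gpus
parser.add_argument("--num_nodes", type=int, default=1)      # = number of SLURM nodes
# parse_known_args so the hyperparameters the sweep passes as --name=value flags don't error out
args, _ = parser.parse_known_args()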

So the basic flow of what's going on here:

  1. Make the sweep configuration in sweep.yaml. Instead of the default python3 ${args}, set the command to echo python3 ${args}. Run wandb sweep sweep.yaml. Put the created sweep_id and your API key into the SLURM script.

  2. When the SLURM script starts running, ALL of the ntasks tasks will execute wandb login $API_KEY.

  3. Only ONE task will execute wandb agent --count 1 $SWEEP_ID. This creates a run for the sweep and captures the echoed python3 ${args} command via stdout and the run_id via stderr.

  4. Because the wandb sweep command is echo rather than python3, the call to wandb agent will finish immediately. So we have to resume the run with RUN_ID in our pytorch lightning code. I use an argument called wandb_resume_version to do this.

  5. ALL tasks will execute the SWEEP_COMMAND.

So once again, this is a roundabout way to do it, but I couldn't find a better solution. I'll also say that this would be greatly simplified if there were a way to get the run_id into the wandb command. Is it possible to add run_id like below in sweep.yaml?

command:
  - ${env}
  - echo
  - python3
  - ${program}
  - ${args}
  - ${run_id}

nate-wandb commented 1 year ago

Hi @OFSkean, I can make a feature request to add - ${run_id} to the command. I'm glad you have a workaround in the meantime, even though it's not ideal.

I can also note that a general way to run SLURM + Sweep + DDP has been requested as well, since the above seems more like a workaround than an official way to run this.

OFSkean commented 1 year ago

Hi @nate-wandb, yes, please make a feature request for ${run_id}. It would be great if it added the run_id as an argument to the command, e.g. python3 main.py --random-arg 42 --run_id abcdefg.

nate-wandb commented 1 year ago

Ok, I've submitted this to the team and can follow up once they have a chance to look into this.

TheLukaDragar commented 1 year ago

Hi, any updates yet?

nate-wandb commented 1 year ago

Hi @TheLukaDragar, we are currently working on better supporting Sweeps on Slurm as the solution, but unfortunately this work is not expected to land until early 2024.

sararb commented 9 months ago

Hi @nate-wandb, any updates on supporting Sweeps on Slurm?

nate-wandb commented 9 months ago

Hi @sararb, the solution for this is going to be enabling our Launch product to run Slurm jobs. Unfortunately, this is still potentially a few quarters away. I'm bumping the priority on this, though, since it has been requested several times, to see if we can get it planned sooner. I'll provide an update as soon as there is any progress.

exalate-issue-sync[bot] commented 5 months ago

WandB Internal User commented: thesofakillers commented: I can't even get WandB Sweep + SLURM + Pytorch Lightning (1 GPU) working. I just get "ValueError: signal only works in main thread"

MostHumble commented 4 months ago

> Hi @sararb, the solution for this is going to be enabling our Launch product to run Slurm jobs. Unfortunately, this is still potentially a few quarters away. I'm bumping the priority on this, though, since it has been requested several times, to see if we can get it planned sooner. I'll provide an update as soon as there is any progress.

Hi @nate-wandb, any updates on this?

ArtsiomWB commented 4 months ago

Hey @MostHumble, no updates on this as of yet. We are still working on Sweeps/Launch on SLURM, and that is the way this will be fixed, but we still do not have an ETA. Apologies for the delay.

Kin-Zhang commented 3 months ago

> my first instinct would be to only use a single task and use Trainer(accelerator="gpu", devices=4, strategy="ddp") in your train.py to spin up the parallel GPU processes. You will need to put wandb.init() in an if block like this:

From: https://github.com/wandb/wandb/issues/5695#issuecomment-1587958030

I tried this approach, but since I'm using the Python CLI, if we only call wandb.init() in rank 0, the next sweep agent cannot initialize across multiple GPUs...
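Something I might try next (rough, untested sketch, not an official wandb or Lightning API): once the DDP process group exists, have rank 0 broadcast the config it got from the agent to the other ranks with torch.distributed.broadcast_object_list:

import torch.distributed as dist

def share_sweep_config(config_dict=None):
    # Call after the process group is initialized (e.g. from LightningModule.setup()).
    # Rank 0 passes a plain dict built from wandb.config; every other rank passes None.
    obj = [config_dict]
    dist.broadcast_object_list(obj, src=0)  # afterwards obj[0] is rank 0's dict on every rank
    return obj[0]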