TUM-DAML / seml

SEML: Slurm Experiment Management Library

Add support for multi node/task jobs #135

Closed: n-gao closed this 3 months ago

n-gao commented 4 months ago

This PR adds support for multi-task and multi-node jobs.

Changes

I made the following changes to support this:

Additional updates

Example JAX experiment:

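import jax
import socket
from seml import Experiment
# Private JAX module: its import path and method signatures may change between releases.
from jax._src.clusters.slurm_cluster import SlurmCluster

ex = Experiment()

@ex.automain
def main(seed=0):
    # Global rank of this Slurm task; unused below because initialize() detects it itself.
    pid = SlurmCluster.get_process_id()
    # Coordinator "host:port", derived by JAX from the Slurm environment.
    address = SlurmCluster.get_coordinator_address()
    host, port = address.split(":")
    # Resolve the coordinator hostname to an IP address before connecting.
    host = socket.gethostbyname(host)
    # Process count and ID are not passed, so JAX auto-detects them from Slurm.
    jax.distributed.initialize(f"{host}:{port}")
    # Each task returns its JAX process index as the experiment result.
    return jax.process_index()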
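When `jax.distributed.initialize` receives only the coordinator address, it fills in the process count and rank from the environment Slurm sets for each task. A minimal sketch of reading those values directly, assuming the standard Slurm task variables `SLURM_PROCID`, `SLURM_NTASKS` and `SLURM_LOCALID`; `SlurmTaskInfo` and `slurm_task_info` are hypothetical names, not part of seml or JAX:

```python
import os
from typing import Mapping, NamedTuple


class SlurmTaskInfo(NamedTuple):
    process_id: int     # global rank of this task (SLURM_PROCID)
    num_processes: int  # total number of tasks (SLURM_NTASKS)
    local_id: int       # rank of this task on its own node (SLURM_LOCALID)


def slurm_task_info(env: Mapping[str, str] = os.environ) -> SlurmTaskInfo:
    """Hypothetical helper: read this task's rank and world size from Slurm's env vars."""
    return SlurmTaskInfo(
        process_id=int(env["SLURM_PROCID"]),
        num_processes=int(env["SLURM_NTASKS"]),
        local_id=int(env["SLURM_LOCALID"]),
    )


if __name__ == "__main__":
    # Simulated environment of the second task in a two-task job.
    fake_env = {"SLURM_PROCID": "1", "SLURM_NTASKS": "2", "SLURM_LOCALID": "1"}
    print(slurm_task_info(fake_env))
    # prints: SlurmTaskInfo(process_id=1, num_processes=2, local_id=1)
```

With these values the call in the experiment could be made explicit, for example `jax.distributed.initialize(f"{host}:{port}", num_processes=info.num_processes, process_id=info.process_id)`, which may help where JAX's automatic cluster detection is not available.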

YAML configuration:

seml:
  executable: test.py
  name: distributed_test
  output_dir: logs
  project_root_dir: .

slurm:
  experiments_per_job: 1
  sbatch_options:
    mem: 16G
    cpus-per-task: 1
    time: 0-01:00
    partition: gpu_all
    gres: gpu:2
    nodes: 1
    ntasks: 2

fixed:
  seed: 0

This example starts two tasks on the same node, each with one GPU.
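The one-GPU-per-task split amounts to dividing the node's two GPUs (`gres: gpu:2`) among its two tasks (`ntasks: 2`) by node-local rank. A minimal sketch of that arithmetic, assuming GPUs are handed out in equal contiguous blocks indexed by the task's `SLURM_LOCALID`; `gpus_for_task` is a hypothetical helper, not a seml or JAX API, and the actual binding depends on the cluster's Slurm and JAX configuration:

```python
from typing import List


def gpus_for_task(local_id: int, gpus_per_node: int, tasks_per_node: int) -> List[int]:
    """Hypothetical helper: split a node's GPUs into equal contiguous blocks, one per task."""
    if gpus_per_node % tasks_per_node != 0:
        raise ValueError("GPUs per node must be divisible by tasks per node")
    if not 0 <= local_id < tasks_per_node:
        raise ValueError("local_id must be in [0, tasks_per_node)")
    per_task = gpus_per_node // tasks_per_node
    start = local_id * per_task
    return list(range(start, start + per_task))


if __name__ == "__main__":
    # The example job: two GPUs and two tasks on a single node.
    for local_id in range(2):
        print(local_id, gpus_for_task(local_id, gpus_per_node=2, tasks_per_node=2))
    # prints:
    # 0 [0]
    # 1 [1]
```

If the default placement is not what a job needs, indices like these could be passed to `jax.distributed.initialize` through its `local_device_ids` argument.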