octo-models / octo

Octo is a transformer-based robot policy trained on a diverse mix of 800k robot trajectories.
https://octo-models.github.io/
MIT License

Training on a GPU cluster #84

Open: budzianowski opened this issue 6 months ago

budzianowski commented 6 months ago

Hi team, thank you for sharing this fantastic work. I initialize the cluster with `jax.distributed.initialize()` and run the command below:

```bash
#!/bin/bash
#SBATCH --job-name=octo_train      # Job name
#SBATCH --nodes=2                  # Number of nodes
#SBATCH --ntasks=16                # Total tasks (2 nodes x 8 tasks)
#SBATCH --ntasks-per-node=8        # Tasks per node
#SBATCH --nodelist=compute-permanent-node-493,compute-permanent-node-580
#SBATCH --gpus-per-node=8          # GPUs per node (adjust as needed)
#SBATCH --time=12:00:00            # Time limit hrs:min:sec

srun python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
```

Running it, I hit what I suspect is a data loader issue: `AssertionError: horizon must be <= max_horizon` (256 vs. 10), which suggests to me that the batch is not being split. Have you experienced a similar issue, or have you only trained on TPUs before?
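For reference, here is a minimal sketch of the per-process and per-device batch sizes I would expect under plain data parallelism. It assumes a global batch of 256 (the number in the error), and the helper `split_batch` is hypothetical, not part of Octo. In a real run the two counts would come from `jax.process_count()` and `jax.local_device_count()` after `jax.distributed.initialize()`. How many GPUs each SLURM task sees depends on the GPU binding, so both layouts are shown.

```python
def split_batch(global_batch: int, process_count: int, local_device_count: int) -> tuple[int, int]:
    """Return (per-process batch, per-device batch) for data-parallel training.

    Hypothetical helper for illustration only. In JAX the counts would be
    jax.process_count() and jax.local_device_count().
    """
    total_devices = process_count * local_device_count
    # Every device must get the same number of examples.
    if global_batch % total_devices != 0:
        raise ValueError(
            f"global batch {global_batch} is not divisible by {total_devices} devices"
        )
    per_process = global_batch // process_count        # rows each host loads
    per_device = per_process // local_device_count     # rows each GPU sees
    return per_process, per_device


if __name__ == "__main__":
    # Assumed layout A: 16 SLURM tasks, each bound to a single GPU.
    print(split_batch(256, 16, 1))
    # Assumed layout B: one process per node driving all 8 local GPUs.
    print(split_batch(256, 2, 8))
```

In both layouts each GPU should see 16 examples. If a process instead ends up with all 256 rows, the data loader is not sharding by process.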