Open · budzianowski opened 6 months ago
Hi team, thank you for sharing this fantastic work. I initialize the cluster with:
```python
jax.distributed.initialize()
```
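For context, this is roughly how I check that the cluster comes up (a minimal sketch on my side; it assumes a recent JAX release where `initialize()` auto-detects the SLURM environment):

```python
import jax

# Assumes jax.distributed.initialize() has already been called and that
# JAX picked up the SLURM environment variables automatically.
# With the 2 nodes x 8 tasks layout below, I would expect 16 processes.
print(f"process {jax.process_index()}/{jax.process_count()}, "
      f"local devices: {jax.local_device_count()}, "
      f"global devices: {jax.device_count()}")
```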
I then launch training with the SLURM script below:

```bash
#!/bin/bash
#SBATCH --job-name=octo_train    # Job name
#SBATCH --nodes=2                # Number of nodes
#SBATCH --ntasks=16
#SBATCH --ntasks-per-node=8
#SBATCH --nodelist=compute-permanent-node-493,compute-permanent-node-580
#SBATCH --gpus-per-node=8        # Request 8 GPUs per node (adjust as needed)
#SBATCH --time=12:00:00          # Time limit hrs:min:sec

srun python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
```
Training then fails with what I suppose is a data loader issue:

```
AssertionError: horizon must be <= max_horizon
```

It reports 256 vs 10, which tells me the batch is not being split across processes (see the sketch below for the split I would expect). Have you experienced a similar issue, or have you only trained on TPUs so far?
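For concreteness, here is the split I would expect, as a rough sketch (the batch-size numbers are taken from the error above; only `jax.process_count()` is standard JAX):

```python
import jax

# Hypothetical numbers matching the error above: a global batch of 256
# split across 16 processes should leave each host with a local batch of 16.
global_batch_size = 256
local_batch_size = global_batch_size // jax.process_count()  # 16 with 2x8 tasks

# If each process instead receives the full 256-element batch, a downstream
# reshape could misread that axis, consistent with the
# "horizon must be <= max_horizon" assertion (256 vs 10).
print(f"expected local batch on this host: {local_batch_size}")
```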