facebookresearch / minimax

Efficient baselines for autocurricula in JAX.
Apache License 2.0

[DRRunner] corrected the negative learning rate in the schedule_function in Domain Randomisation Runner #10

Open RobbenRibery opened 2 months ago

RobbenRibery commented 2 months ago

In the reset(self, rng) method, the learning rate is negative as initially specified, which causes learning to break down completely. After turning it into a positive value and passing the scheduler into the optax chain (see line 153), ACCEL achieves generalisation on OOD envs [ref: WANDB attached].
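For reference, a minimal sketch of the corrected setup, assuming an Adam-based optax chain and placeholder hyperparameter values (the actual DRRunner attributes and chain may differ):

```python
import optax

# Placeholder hyperparameters standing in for self.lr, self.lr_final,
# self.anneal_steps, and the runner's grad-norm clip (all assumed values).
lr, lr_final, anneal_steps, max_grad_norm = 3e-4, 3e-4, 0, 0.5

# The fix: pass a positive init_value (float(lr), not -float(lr)) ...
schedule_fn = optax.linear_schedule(
    init_value=float(lr),
    end_value=float(lr_final),
    transition_steps=anneal_steps,
)

# ... and feed the schedule into the optax chain instead of a raw float.
tx = optax.chain(
    optax.clip_by_global_norm(max_grad_norm),
    optax.adam(learning_rate=schedule_fn),
)
```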

[Screenshot: WANDB results, 2024-08-24 19:21:20]
minqi commented 2 months ago

Hi @RobbenRibery, that schedule_fn is a leftover from code we did not use in our experiments (the original L153 in your diff uses float(self.lr), without the negative sign).

Looking at optax.linear_schedule, your change should correctly default to a constant function returning the initial learning rate when self.anneal_steps == 0, so I think this is safe to merge. @samvelyan
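A quick check of that fallback behaviour (values here are arbitrary; optax logs a warning for a non-positive transition_steps but still returns a well-defined constant schedule):

```python
import optax

# With transition_steps == 0, linear_schedule degenerates to a constant
# schedule that always returns init_value, ignoring end_value.
schedule_fn = optax.linear_schedule(
    init_value=3e-4, end_value=0.0, transition_steps=0
)
print(schedule_fn(0), schedule_fn(1_000_000))  # both 3e-4
```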

RobbenRibery commented 2 months ago

Hi @minqi, thanks for your comment! I see your point. We could enforce something like self.anneal_steps == 0 or self.lr_final == self.lr as an invariant.
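Something like this hypothetical guard (attribute names follow this thread, not necessarily the codebase):

```python
def check_constant_lr(lr: float, lr_final: float, anneal_steps: int) -> None:
    # Either annealing is disabled, or it is a no-op because the final
    # learning rate equals the initial one.
    assert anneal_steps == 0 or lr_final == lr, (
        "annealing enabled with lr_final != lr; schedule is not constant"
    )
```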

Happy to run some experiments to see whether annealing helps further stabilise the training.

minqi commented 2 months ago

Hi @RobbenRibery, the default setting for self.anneal_steps is 0, and for self.lr_final it is None, in which case it defaults to the same value as self.lr, so no changes there are necessary.
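In other words, the effective defaulting is roughly this (a sketch, not verbatim code from the repo):

```python
import optax

lr, lr_final, anneal_steps = 3e-4, None, 0  # the defaults described above

lr_final = lr if lr_final is None else lr_final  # None falls back to lr
schedule_fn = optax.linear_schedule(
    init_value=lr, end_value=lr_final, transition_steps=anneal_steps
)  # constant at lr, since anneal_steps == 0
```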

We previously looked at linear annealing, but found it mostly hurt final policy performance on OOD tasks.

RobbenRibery commented 2 months ago

Thanks, appreciated!

RobbenRibery commented 2 months ago

Hi Minqi, @minqi, I also find that by setting the following:

```bash
export XLA_FLAGS='--xla_gpu_deterministic_ops=true --xla_gpu_autotune_level=0'
export TF_DETERMINISTIC_OPS=1
python -m minimax.train -- .....
```

I could make the ACCEL runs deterministic, at roughly 20% of the SPS (steps per second) of the non-deterministic runs. Without these flags, even when every RNG split is set correctly, I would still get different results across runs.
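The same environment variables can also be set from Python, provided this runs before JAX initialises its backend (an assumption: XLA_FLAGS is read at backend start-up, so it must precede any jax/minimax import):

```python
import os

# Must be set before importing jax / minimax.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_deterministic_ops=true --xla_gpu_autotune_level=0"
)
os.environ["TF_DETERMINISTIC_OPS"] = "1"
```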

ref: WANDB attached:

[Screenshot: WANDB results, 2024-08-31 14:53:43]