araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

[Bug] TQC Hyperparameter optimization: Results do not match the reference. This is likely a bug/unexpected loss of precision. #44

Closed: edmund735 closed this issue 2 weeks ago

edmund735 commented 3 months ago

🐛 Bug

Hi,

When I try to run TQC hyperparameter optimization with multiple jobs (n-jobs>1) with a GPU (this also happens with multiple CPU cores and n-jobs=1), it gives me this error:

2024-04-07 14:35:59.992779: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 0: -inf, expected -0.000287323
2024-04-07 14:35:59.992804: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 1: -inf, expected -0.000267224
2024-04-07 14:35:59.992808: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 2: -inf, expected -0.000226477
2024-04-07 14:35:59.992811: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 3: -inf, expected -0.000281823
2024-04-07 14:35:59.992813: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 4: -inf, expected -0.000262532
2024-04-07 14:35:59.992815: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 5: -inf, expected -0.000252724
2024-04-07 14:35:59.992818: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 6: -inf, expected -0.000250007
2024-04-07 14:35:59.992820: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 7: -inf, expected -0.000265674
2024-04-07 14:35:59.992823: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 8: -inf, expected -0.00021464
2024-04-07 14:35:59.992825: E external/xla/xla/service/gpu/buffer_comparator.cc:143] Difference at 9: -inf, expected -0.000204733
E0407 14:35:59.992828  798907 triton_autotuner.cc:766] Results do not match the reference. This is likely a bug/unexpected loss of precision.
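These lines come from the GPU autotuner. It tries several candidate kernels for a matrix product ("All algorithms tried for %dot.384" below) and checks each candidate's output against a reference result. Here a candidate returned `-inf` where the reference holds small negative values. As a rough illustration only (plain Python, not XLA's actual `buffer_comparator` logic, and with a made-up tolerance), such a check could look like this:

```python
import math


def compare_buffers(candidate, reference, rel_tol=1e-2, max_reports=10):
    """Report (index, got, expected) for elements where two outputs disagree.

    Simplified stand-in for an autotuner's correctness check; rel_tol is a
    hypothetical tolerance, not XLA's real setting.
    """
    mismatches = []
    for i, (got, expected) in enumerate(zip(candidate, reference)):
        # math.isclose returns False for -inf vs. any finite value (and for NaN).
        if not math.isclose(got, expected, rel_tol=rel_tol):
            mismatches.append((i, got, expected))
            if len(mismatches) >= max_reports:
                break
    return mismatches


# Reference values copied from the first three log lines; the candidate produced -inf.
reference = [-0.000287323, -0.000267224, -0.000226477]
candidate = [float("-inf")] * len(reference)
for i, got, expected in compare_buffers(candidate, reference):
    print(f"Difference at {i}: {got}, expected {expected}")
```

When any element differs, the candidate is rejected; in the log above every tried algorithm was rejected, which leads to the "Falling back to default algorithm" error.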

To Reproduce

python rl-baselines3-zoo/train_sbx.py --algo tqc --env Pendulum-v1 -n 5000 --n-trials 50 --num-threads 1 --n-jobs 4 --log-interval 4900 --eval-episodes 16 --n-eval-envs 8 --seed 8 --vec-env "dummy" -optimize --sampler tpe --pruner median --n-startup-trials 10
[W 2024-04-07 14:36:00,208] Trial 16 failed with parameters: {'gamma': 0.995, 'learning_rate': 0.23149128592335125, 'batch_size': 1024, 'buffer_size': 10000, 'learning_starts': 1000, 'train_freq': 16, 'tau': 0.08, 'log_std_init': -0.3684256821552643, 'net_arch': 'medium', 'n_quantiles': 32, 'top_quantiles_to_drop_per_net': 30} because of the following error: XlaRuntimeError('INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.  Per-algorithm errors:\n  Results do not match the reference. This is likely a bug/unexpected loss of precision.

Traceback (most recent call last):
  File "/home/.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs)  # type: ignore[arg-type]
  File "/home/.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.conda/envs/.../lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.

Traceback (most recent call last):
  File "/scratch/network/.../.../rl-baselines3-zoo/train_sbx.py", line 19, in <module>
    train()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/train.py", line 275, in train
    exp_manager.hyperparameters_optimization()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 874, in hyperparameters_optimization
    study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/study.py", line 451, in optimize
    _optimize(
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 99, in _optimize
    f.result()
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/.../.conda/envs/.../lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 159, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 247, in _run_trial
    raise func_err
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type]
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/.../lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: All algorithms tried for %dot.384 = f32[1024,512]{1,0} dot(f32[1024,32]{1,0} %broadcast.2180, f32[512,32]{1,0} %parameter_0.14), lhs_contracting_dims={1}, rhs_contracting_dims={1}, metadata={op_name="jit(_train)/jit(main)/while/body/cond/branch_1_fun/jit(update_actor)/transpose(jvp(Critic))/Dense_2/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py" source_line=793} failed. Falling back to default algorithm.
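The failing instruction is an ordinary matrix product. With `lhs_contracting_dims={1}` and `rhs_contracting_dims={1}` (and no batch dimensions), `dot_general` contracts the last axis of both operands, so `f32[1024,32]` and `f32[512,32]` give `f32[1024,512]`, the same as `lhs @ rhs.T`; in JAX this corresponds to `jax.lax.dot_general(lhs, rhs, (((1,), (1,)), ((), ())))`. The metadata places it in the gradient (`transpose(jvp(Critic))`) of the critic's `Dense_2` layer inside `update_actor`. The NumPy sketch below only reproduces the shapes on random data; it does not reproduce the `-inf` result.

```python
import numpy as np

rng = np.random.default_rng(0)
lhs = rng.standard_normal((1024, 32), dtype=np.float32)  # f32[1024,32]
rhs = rng.standard_normal((512, 32), dtype=np.float32)   # f32[512,32]

# dimension_numbers=(((1,), (1,)), ((), ())): contract axis 1 of both operands, no batch dims.
out = np.tensordot(lhs, rhs, axes=([1], [1]))
print(out.shape, out.dtype)

# Equivalent to a plain matrix product with the right operand transposed.
assert np.allclose(out, lhs @ rhs.T, rtol=1e-5, atol=1e-5)
```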


Additional context

I've noticed the bug does not occur with n-jobs=1, only when running multiple jobs. Maybe it has something to do with the way Optuna runs multiple jobs?
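Judging from the second traceback above, which passes through `concurrent/futures/thread.py` and `f.result()` in `optuna/study/_optimize.py`, Optuna appears to run trials with `n_jobs > 1` in worker threads of a single Python process, and a trial's exception is re-raised in the main thread when its future is collected (which would explain why the same `XlaRuntimeError` is printed twice). The stdlib-only sketch below mimics that pattern; `fake_trial` is a hypothetical stand-in for the zoo's `objective`, not Optuna's actual code.

```python
import os
import threading
from concurrent.futures import ThreadPoolExecutor


def fake_trial(number):
    """Hypothetical stand-in for an objective: records which process/thread ran it."""
    if number == 3:
        # Simulates a trial that fails inside model.learn().
        raise RuntimeError(f"trial {number} failed")
    return number, os.getpid(), threading.get_ident()


results, errors = [], []
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(fake_trial, n) for n in range(8)]
    for f in futures:
        try:
            results.append(f.result())  # re-raises the worker thread's exception here
        except RuntimeError as exc:
            errors.append(str(exc))

pids = {pid for _, pid, _ in results}
print("processes used:", pids)
print("errors re-raised in the main thread:", errors)
```

Every trial reports the same PID, so with this setup all trials would share one JAX runtime and its GPU state, whereas separate processes would not.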


edmund735 commented 3 months ago

I ran it again with a new version of JAX (jaxlib 0.4.23 cuda120py310h3cc97ca_20), and now it gives a different error:

'''
[I 2024-04-07 15:55:02,691] A new study created in memory with name: no-name-0da03417-c265-43a2-a55b-10d9750abcca
2024-04-07 15:55:23.576297: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error: INTERNAL: Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture
2024-04-07 15:55:23.576337: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error: INTERNAL: CaptureGpuGraph failed (Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current tracing scope: triton_gemm_dot.86): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture
2024-04-07 15:55:23.576394: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current tracing scope: triton_gemm_dot.86): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#prefix=jit(_train)/jit(main)/while/body,hlo_module=jittrain,program_id=116#.
INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.graph.launch' failed: CaptureGpuGraph failed (Failed to launch CUDA kernel: triton_gemm_dot_74; block dims: 256x1x1; grid dims: 32x1x1; shared memory size: 24576: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current tracing scope: triton_gemm_dot.86): INTERNAL: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture; current profiling annotation: XlaModule:#prefix=jit(_train)/jit(main)/while/body,hlo_module=jittrain,program_id=116#.

Sampled hyperparams: {'batch_size': 1024, 'buffer_size': 100000, 'ent_coef': 'auto', 'gamma': 0.9999, 'gradient_steps': 1, 'learning_rate': 0.004315216575412321, 'learning_starts': 0, 'policy_kwargs': {'log_std_init': -1.4239746627852474, 'n_quantiles': 31, 'net_arch': [64, 64], 'top_quantiles_to_drop_per_net': 25, 'use_sde': False}, 'target_entropy': 'auto', 'tau': 0.02, 'top_quantiles_to_drop_per_net': 25, 'train_freq': 1}
2024-04-07 15:55:23.577027: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1883] could not synchronize on CUDA context: CUDA_ERROR_STREAM_CAPTURE_UNSUPPORTED: operation not permitted when stream is capturing :: Begin stack trace
_PyObject_MakeTpCall

_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall

_PyObject_MakeTpCall
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall

_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

PyObject_Call
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall

_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault

PyObject_Call
_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
clone

End stack trace

[I 2024-04-07 15:55:23,577] Trial 1 pruned.
[W 2024-04-07 15:55:23,606] Trial 3 failed with parameters: {'gamma': 1, 'learning_rate': 0.03739146141228411, 'batch_size': 256, 'buffer_size': 100000, 'learning_starts': 1000, 'train_freq': 1, 'tau': 0.02, 'log_std_init': -1.1735175685607313, 'net_arch': 'small', 'n_quantiles': 26, 'top_quantiles_to_drop_per_net': 24} because of the following error: XlaRuntimeError('INTERNAL: Failed to synchronize GPU for autotuning.').
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type]
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to synchronize GPU for autotuning.

[W 2024-04-07 15:55:23,607] Trial 3 failed with value None.
[I 2024-04-07 15:55:44,276] Trial 2 finished with value: -149.14224087500003 and parameters: {'gamma': 0.995, 'learning_rate': 0.005830150992686316, 'batch_size': 2048, 'buffer_size': 10000, 'learning_starts': 0, 'train_freq': 8, 'tau': 0.01, 'log_std_init': -3.101106181907312, 'net_arch': 'medium', 'n_quantiles': 13, 'top_quantiles_to_drop_per_net': 1}. Best is trial 2 with value: -149.14224087500003.
[I 2024-04-07 15:55:44,442] Trial 0 finished with value: -1286.9111508125 and parameters: {'gamma': 0.995, 'learning_rate': 0.03380452664776398, 'batch_size': 128, 'buffer_size': 1000000, 'learning_starts': 1000, 'train_freq': 4, 'tau': 0.02, 'log_std_init': -3.2686941182290763, 'net_arch': 'big', 'n_quantiles': 45, 'top_quantiles_to_drop_per_net': 13}. Best is trial 2 with value: -149.14224087500003.
[I 2024-04-07 15:55:44,610] Trial 4 finished with value: -408.14151849999996 and parameters: {'gamma': 0.99, 'learning_rate': 0.022024554072114278, 'batch_size': 512, 'buffer_size': 100000, 'learning_starts': 1000, 'train_freq': 8, 'tau': 0.02, 'log_std_init': 0.9307981026739451, 'net_arch': 'big', 'n_quantiles': 35, 'top_quantiles_to_drop_per_net': 29}. Best is trial 2 with value: -149.14224087500003.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/scratch/network/.../.../rl-baselines3-zoo/train_sbx.py", line 19, in <module>
    train()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/train.py", line 275, in train
    exp_manager.hyperparameters_optimization()
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 874, in hyperparameters_optimization
    study.optimize(self.objective, n_jobs=self.n_jobs, n_trials=self.n_trials)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/study.py", line 451, in optimize
    _optimize(
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 99, in _optimize
    f.result()
  File "/home/.../.conda/envs/...1/lib/python3.10/concurrent/futures/_base.py", line 451, in result
    return self.__get_result()
  File "/home/.../.conda/envs/...1/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/home/.../.conda/envs/...1/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 159, in _optimize_sequential
    frozen_trial = _run_trial(study, func, catch)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 247, in _run_trial
    raise func_err
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/scratch/network/.../.../rl-baselines3-zoo/rl_zoo3/exp_manager.py", line 793, in objective
    model.learn(self.n_timesteps, callback=callbacks, **learn_kwargs) # type: ignore[arg-type]
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 183, in learn
    return super().learn(
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 347, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/.../.conda/envs/...1/lib/python3.10/site-packages/sbx/tqc/tqc.py", line 220, in train
    ) = self._train(
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to synchronize GPU for autotuning.

'''

araffin commented 3 months ago

This might be related to Jax not handling multi-threading/multi-processing well.

You should probably have a look at distributed tuning using a shared database (I would recommend the log format): https://rl-baselines3-zoo.readthedocs.io/en/master/guide/tuning.html
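A minimal sketch of that suggestion, assuming the zoo's `--study-name` and `--storage` options and that a storage path ending in `.log` selects Optuna's journal file storage, as described in the linked tuning guide (check the guide for the exact options). Instead of one process with `--n-jobs 4`, it starts several independent workers with `--n-jobs 1` that share one study, so each trial runs in its own JAX process. The study name, storage path, worker count and per-worker seeds are hypothetical choices for illustration.

```python
import subprocess
import sys

# Hypothetical settings for illustration; adapt to your setup.
N_WORKERS = 4
TOTAL_TRIALS = 50
STUDY_NAME = "tqc-pendulum"    # hypothetical study name
STORAGE = "tqc_pendulum.log"   # assumed: a ".log" path selects Optuna's journal storage
BASE_SEED = 8


def worker_command(worker_id):
    """Build the train_sbx.py command line for one independent tuning worker."""
    trials_per_worker = -(-TOTAL_TRIALS // N_WORKERS)  # ceiling division
    return [
        sys.executable, "rl-baselines3-zoo/train_sbx.py",
        "--algo", "tqc", "--env", "Pendulum-v1", "-n", "5000",
        "-optimize", "--sampler", "tpe", "--pruner", "median",
        "--n-startup-trials", "10", "--num-threads", "1",
        "--n-trials", str(trials_per_worker),
        "--n-jobs", "1",                       # one trial at a time per process
        "--seed", str(BASE_SEED + worker_id),  # distinct seeds so workers sample differently
        "--study-name", STUDY_NAME,
        "--storage", STORAGE,
    ]


def launch_workers():
    """Start every worker as a separate OS process and wait for all of them."""
    procs = [subprocess.Popen(worker_command(i)) for i in range(N_WORKERS)]
    return [p.wait() for p in procs]
```

Calling `launch_workers()` from the directory the original command was run in starts four separate Python processes, each with its own JAX runtime, instead of four threads sharing one. Running the same commands by hand in separate shells (or as separate cluster jobs) is equivalent.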