google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0
2.31k stars 251 forks source link

SAC training throws segfault #542

Open varunagrawal opened 4 days ago

varunagrawal commented 4 days ago

I am trying to train an RL model using SAC and compare it to PPO by using the tutorial in this notebook, but I get a segfault when the training code reaches the line:

obs_size = env.observation_size

I am using jax==0.4.34 and brax==0.11.0 on Ubuntu 22.04, with CUDA 12.6 and NVIDIA driver version 560.35.03.

Here is my script:

"""
Train Barkour model via Soft Actor Critic (SAC) in an environment with no obstacle.

The environment initializes the robot at a random start point
normally distributed around a mean and the robot has to
reach the same goal point.

python scripts/train_barkour_straight_sac.py
"""

import functools
from pathlib import Path

import jax
from brax import envs
from brax.io import model
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac

from fill.envs import domain_randomize
from fill.utils.progress import Progress, print_progress_times
from fill.utils.video import render_video

def main():
    """Train a SAC policy on the obstacle-free Barkour environment.

    Builds a Soft Actor-Critic training function with fixed hyperparameters
    and domain randomization. It then creates fresh training and evaluation
    environments, prepares the output directory
    ``results/train_barkour_straight_sac``, and runs training while reporting
    progress through ``Progress``.
    """
    env_name = 'barkour_straight'

    # This first instance is only consulted for its episode length below.
    length_probe_env = envs.get_environment(env_name)

    # Four hidden layers of 128 units each.
    network_factory = functools.partial(
        sac_networks.make_sac_networks,
        hidden_layer_sizes=(128,) * 4,
    )

    total_steps = 80_000_000

    sac_config = dict(
        num_timesteps=total_steps,
        episode_length=length_probe_env.eps_length,
        action_repeat=1,
        num_envs=64,
        learning_rate=3.0e-4,
        discounting=0.97,
        seed=0,
        batch_size=256,
        num_evals=10,
        normalize_observations=True,
        reward_scaling=1,
        min_replay_size=200,
        max_replay_size=40_000,
        network_factory=network_factory,
        randomization_fn=domain_randomize,
    )
    train_sac = functools.partial(sac.train, **sac_config)

    # Use fresh environments for training and evaluation, because an
    # environment's internals may be overwritten by tracers from the domain
    # randomization function.
    train_env = envs.get_environment(env_name)
    eval_env = envs.get_environment(env_name)
    eval_env.curriculum_level = 0  # alternative tried: np.random.randint(2)

    output_dir = Path("results", "train_barkour_straight_sac")
    # Create the output directory (and any missing parents) if it doesn't
    # exist yet.
    output_dir.mkdir(parents=True, exist_ok=True)

    progress_fn = Progress(num_timesteps=total_steps,
                           save_path=output_dir / 'graph')
    inference_factory, policy_params, _ = train_sac(
        environment=train_env,
        progress_fn=progress_fn,
        eval_env=eval_env,
    )

    # Follow-up steps, currently disabled:
    # print_progress_times(progress_fn)
    #
    # # Save and reload params.
    # model_path = output_dir / 'mjx_brax_quadruped_policy'
    # model.save_params(model_path, policy_params)
    # print(f"Loading model from {model_path}")
    # policy_params = model.load_params(model_path)
    #
    # # Visualize the trained policy from a fresh environment.
    # eval_env = envs.get_environment(env_name)
    # rng = jax.random.PRNGKey(0)
    # inference_fn = inference_factory(policy_params)
    # render_video(eval_env, inference_fn, output_dir, rng)


if __name__ == "__main__":
    main()

Interestingly, the same script works fine with PPO (I only change sac.train to ppo.train and sac_networks.make_sac_networks to ppo_networks.make_ppo_networks). That is why I think the issue is either in how SAC uses my custom environment or somewhere within SAC itself.

Using faulthandler, I get the following trace for the segfault:

INFO:absl:local_device_count: 1; total_device_count: 1
Fatal Python error: Segmentation fault

Current thread 0x000072b2f2aa7b80 (most recent call first):
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/compiler.py", line 267 in backend_compile
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/profiler.py", line 333 in wrapper
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/compiler.py", line 655 in _compile_and_write_cache
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/compiler.py", line 427 in compile_or_get_cached
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2641 in _cached_compilation
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2829 in from_hlo
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2315 in compile
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/pjit.py", line 1669 in _pjit_call_impl_python
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/pjit.py", line 1739 in call_impl_cache_miss
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/pjit.py", line 1764 in _pjit_call_impl
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/core.py", line 948 in process_primitive
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/core.py", line 442 in bind_with_trace
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/core.py", line 2781 in bind
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/pjit.py", line 2003 in _pjit_batcher
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/interpreters/batching.py", line 442 in process_primitive
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/core.py", line 442 in bind_with_trace
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/core.py", line 2781 in bind
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/pjit.py", line 189 in _python_pjit_helper
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/pjit.py", line 356 in cache_miss
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 573 in deferring_binary_op
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/smooth.py", line 646 in _forward
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/scan.py", line 128 in outer_f
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/linear_util.py", line 193 in call_wrapped
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/api.py", line 992 in vmap_f
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/scan.py", line 130 in _nvmap
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/scan.py", line 479 in body_tree
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/smooth.py", line 659 in rne_postconstraint
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/sensor.py", line 440 in sensor_acc
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/forward.py", line 400 in forward
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/mujoco/mjx/_src/forward.py", line 57 in wrapper
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/mjx/pipeline.py", line 72 in init
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/envs/base.py", line 123 in pipeline_init
  File "/home/varun/legged/fill/envs/barkour.py", line 266 in reset
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/envs/base.py", line 144 in observation_size
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/envs/base.py", line 180 in observation_size
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/envs/base.py", line 180 in observation_size
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/envs/base.py", line 180 in observation_size
  File "/home/varun/.pyenv/versions/3.11.7/lib/python3.11/site-packages/brax/training/agents/sac/train.py", line 193 in train
  File "/home/varun/legged/scripts/train_barkour_straight_sac.py", line 70 in main
  File "/home/varun/legged/scripts/train_barkour_straight_sac.py", line 94 in <module>

Extension modules: jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, zstandard.backend_c, msgpack._cmsgpack, yaml._yaml, _cffi_backend, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._flinalg, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial.transform._rotation, xxhash._xxhash, scipy.optimize._minpack2, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, 
scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, psutil._psutil_linux, psutil._psutil_posix, PIL._imaging, embreex.rtcore, embreex.rtcore_scene, embreex.mesh_construction, lxml._elementpath, lxml.etree, shapely.lib, shapely._geos, shapely._geometry_helpers, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, matplotlib._c_internal_utils, matplotlib._path, kiwisolver._cext, matplotlib._image, zmq.backend.cython.context, zmq.backend.cython.message, zmq.backend.cython.socket, zmq.backend.cython._device, zmq.backend.cython._poll, zmq.backend.cython._proxy_steerable, zmq.backend.cython._version, zmq.backend.cython.error, zmq.backend.cython.utils, tornado.speedups (total: 109)

I'll try the environment defined in this tutorial notebook to keep narrowing down the issue, but I figured the amazing brax developers might be able to point me in the right direction.

varunagrawal commented 3 days ago

Even with the default environment, I am still getting a segfault. I am truly lost now.