sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

[BUG] XLA Segmentation Fault #283

Closed (closed by ethanluoyc 1 year ago)

ethanluoyc commented 1 year ago

Describe the bug

Running envpool's XLA step function inside a jitted `jax.lax.scan` crashes the Python process with a segmentation fault on the GPU.

To Reproduce

The following code, which uses the XLA interface, crashes with a segmentation fault when run on the GPU.

```python
import os
import time
from typing import Any, NamedTuple

from absl import app
from absl import logging

# envpool only accepts double (float64) action input
os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
# see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
import envpool
import jax
import flax
import jax.numpy as jnp

class RolloutCarry(NamedTuple):
    handle: Any
    state: Any
    key: Any

@flax.struct.dataclass
class RolloutOutput:
    actions: jnp.ndarray
    timestep: Any

def rollout(env_step_fn, policy, agent_state, rollout_carry, max_steps):
    def _step(carry, timestep):
        del timestep

        rollout_carry = carry
        action_key, key = jax.random.split(rollout_carry.key)
        action = policy(agent_state, rollout_carry.state.observation, action_key)
        handle, next_state = env_step_fn(rollout_carry.handle, action)
        output = RolloutOutput(
            actions=action,
            timestep=rollout_carry.state,
        )
        new_rollout_carry = RolloutCarry(handle=handle, state=next_state, key=key)
        return (new_rollout_carry, output)

    new_rollout_carry, output = jax.lax.scan(_step, rollout_carry, (), length=max_steps)
    return (new_rollout_carry, output)

def main(_):
    num_envs = 64
    num_steps = 32
    total_timesteps = int(3e6)
    num_updates = total_timesteps // (num_envs * num_steps)

    envs = envpool.make("HalfCheetah-v3", env_type="dm", num_envs=num_envs, seed=1)
    # envs = envpool.make("CheetahRun-v1", env_type="dm", num_envs=num_envs, seed=1)
    action_spec = envs.action_spec()

    handle, _, _, step_env = envs.xla()
    state = envs.reset()

    def process_states(states):
        # I am converting the observation to single precision here,
        # but that seems to be the line that causes the crash.
        return states._replace(
            observation=jnp.array(states.observation.obs, dtype=jnp.float32, copy=True)
        )

    params = ()
    carry = RolloutCarry(
        handle=handle,
        state=process_states(state),
        key=jax.random.PRNGKey(0),
    )

    def wrapped_step_env(handle, action):
        handle, state = step_env(handle, action)
        return handle, process_states(state)

    def policy(params, obs, key):
        # I do use float64 for actions.
        return jax.random.uniform(
            key, (num_envs, action_spec.shape[0]), dtype=jnp.float64
        )

    @jax.jit
    def rollout_fn(agent_state, rollout_carry):
        return rollout(wrapped_step_env, policy, agent_state, rollout_carry, num_steps)

    global_step = 0
    for _ in range(1, num_updates + 1):
        update_time_start = time.time()
        carry, experience = rollout_fn(params, carry)
        global_step += num_steps * num_envs
        sps_update = int(num_envs * num_steps / (time.time() - update_time_start))

        jax.block_until_ready(experience)
        # logging.info("global_step=%d, SPS_update=%d", global_step, sps_update)

    envs.close()

if __name__ == "__main__":
    jax.config.update("jax_default_dtype_bits", "32")
    jax.config.update("jax_enable_x64", True)
    jax.config.config_with_absl()
    app.run(main)
```
```
I1025 11:09:56.088881 139954427008832 xla_bridge.py:455] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
I1025 11:09:56.089233 139954427008832 xla_bridge.py:455] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I1025 11:09:56.089317 139954427008832 xla_bridge.py:455] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
Fatal Python error: Segmentation fault

Current thread 0x00007f49ade81740 (most recent call first):
  File "/home/yicheng/projects/corax-mjx/ppo_jax/debug.py", line 93 in main
  File "/home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/absl/app.py", line 254 in _run_main
  File "/home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/absl/app.py", line 308 in run
  File "/home/yicheng/projects/corax-mjx/ppo_jax/debug.py", line 107 in <module>
```
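The reproduction enables `jax_enable_x64` so that float64 actions can be fed to envpool's XLA step while observations are cast down to float32. A minimal sketch of just that dtype handling, with envpool left out entirely; the helper names `to_float32_obs` and `sample_actions` are hypothetical, and the shapes (4 envs, 17-dim observations, 6-dim actions) are illustrative assumptions. It neither reproduces nor avoids the crash.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Must run before any arrays are created; without it JAX falls back to
# float32 when float64 is requested.
jax.config.update("jax_enable_x64", True)


def to_float32_obs(obs):
    """Cast a (possibly float64) observation batch to float32."""
    return jnp.asarray(obs, dtype=jnp.float32)


def sample_actions(key, num_envs, action_dim, minval=0.0, maxval=1.0):
    """Sample uniform float64 actions in [minval, maxval), like the repro's policy."""
    return jax.random.uniform(
        key, (num_envs, action_dim), dtype=jnp.float64, minval=minval, maxval=maxval
    )


obs = to_float32_obs(np.zeros((4, 17), dtype=np.float64))
actions = sample_actions(jax.random.PRNGKey(0), num_envs=4, action_dim=6)
print(obs.dtype, actions.dtype)  # float32 float64
```

In a real policy, `minval`/`maxval` would presumably come from the bounds of `envs.action_spec()` rather than the `[0, 1)` defaults used here.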


System info

Describe the characteristic of your environment:

```python
import envpool, numpy, sys
print(envpool.__version__, numpy.__version__, sys.version, sys.platform)
```

```
0.8.3 1.26.1 3.10.5 (main, Jun 19 2023, 14:30:29) [GCC 9.4.0] linux
```

JAX 0.4.10.
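The version snippet requested by the template does not report JAX or jaxlib, which matter for a crash inside an XLA custom call. A sketch that collects all of them through the standard-library `importlib.metadata`, without importing the packages themselves (handy when importing one of them is what crashes); `package_versions` is a hypothetical helper and the package list is an assumption to adjust as needed.

```python
import sys
from importlib import metadata


def package_versions(names=("envpool", "numpy", "jax", "jaxlib")):
    """Return {package: installed version, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for pkg, ver in package_versions().items():
        print(f"{pkg}: {ver or 'not installed'}")
    print(f"python: {sys.version}")
    print(f"platform: {sys.platform}")
```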

Additional context

I ran it under gdb; this is the backtrace:

```
0x00007fff6856c196 in void AsyncEnvPool<mujoco_gym::HalfCheetahEnv>::SendImpl<std::vector<Array, std::allocator<Array> > const&>(std::vector<Array, std::allocator<Array> > const&) () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/envpool/mujoco/mujoco_gym_envpool.so
(gdb) bt
#0  0x00007fff6856c196 in void AsyncEnvPool<mujoco_gym::HalfCheetahEnv>::SendImpl<std::vector<Array, std::allocator<Array> > const&>(std::vector<Array, std::allocator<Array> > const&) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/envpool/mujoco/mujoco_gym_envpool.so
#1  0x00007fff6856ca18 in CustomCall<AsyncEnvPool<mujoco_gym::HalfCheetahEnv>, XlaSend<AsyncEnvPool<mujoco_gym::HalfCheetahEnv> > >::Gpu(CUstream_st*, void**, char const*, unsigned long) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/envpool/mujoco/mujoco_gym_envpool.so
#2  0x00007fff70a158e0 in xla::runtime::CustomCallHandler<(xla::runtime::CustomCall::RuntimeChecks)1, xla::runtime::CustomCall::FunctionWrapper<&xla::gpu::XlaCustomCallImpl>, xla::runtime::internal::UserData<xla::ServiceExecutableRunOptions const*>, xla::runtime::internal::UserData<xla::DebugOptions const*>, xla::runtime::CustomCall::RemainingArgs, xla::runtime::internal::Attr<std::basic_string_view<char, std::char_traits<char> > >, xla::runtime::internal::Attr<int>, xla::runtime::internal::Attr<std::basic_string_view<char, std::char_traits<char> > > >::call(void**, void**, void**, xla::runtime::PtrMapByType<xla::runtime::CustomCall, 16u> const*, xla::runtime::DiagnosticEngine const*) const ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#3  0x00007fff70a16502 in xla::gpu::XlaCustomCall(xla::runtime::ExecutionContext*, void**, void**, void**) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#4  0x00007fff600f21b0 in __xla__main.192 ()
#5  0x00007fff70c90557 in xla::runtime::Executable::Execute(unsigned int, xla::runtime::Executable::CallFrame&, xla::runtime::Executable::ExecuteOpts const&) const () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#6  0x00007fff7093e5b3 in xla::gpu::GpuRuntimeExecutable::Execute(xla::ServiceExecutableRunOptions const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::vector<unsigned char, std::allocator<unsigned char> > const&, xla::gpu::BufferAllocations const&, xla::gpu::NonAtomicallyUpgradeableRWLock&, xla::BufferAllocation const*) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#7  0x00007fff70924327 in xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime(xla::ServiceExecutableRunOptions const*, xla::gpu::BufferAllocations const&, bool, xla::gpu::NonAtomicallyUpgradeableRWLock&) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#8  0x00007fff7092868a in xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl(xla::ServiceExecutableRunOptions const*, std::variant<absl::lts_20230125::Span<xla::ShapedBuffer const* const>, absl::lts_20230125::Span<xla::ExecutionInput> >) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#9  0x00007fff70929630 in xla::gpu::GpuExecutable::ExecuteAsyncOnStream(xla::ServiceExecutableRunOptions const*, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::HloExecutionProfile*) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#10 0x00007fff7227f2e7 in xla::Executable::ExecuteAsyncOnStreamWrapper(xla::ServiceExecutableRunOptions const*, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >) () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#11 0x00007fff6f677f43 in xla::LocalExecutable::RunAsync(absl::lts_20230125::Span<xla::Shape const* const>, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::ExecutableRunOptions) ()
   from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
#12 0x00007fff6f678ba5 in xla::LocalExecutable::RunAsync(std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::ExecutableRunOptions) () from /home/yicheng/projects/corax-mjx/.venv/lib/python3.10/site-packages/jaxlib/xla_extension.so
```
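gdb supplies the native frames (here ending in `AsyncEnvPool<...>::SendImpl`, called from the `XlaSend` custom call on the GPU stream), while the `Fatal Python error: Segmentation fault` block in the original report has the format of Python's standard-library `faulthandler`, which shows which Python line was executing. If a script does not already print that dump, it can be switched on with `python -X faulthandler` or `PYTHONFAULTHANDLER=1`, or from code as sketched below; `enable_crash_traces` is a hypothetical helper name.

```python
import faulthandler
import sys


def enable_crash_traces(log_path=None):
    """Dump Python tracebacks of all threads on a fatal signal such as SIGSEGV.

    With log_path=None the dump goes to stderr; otherwise it is written to that
    file. The file object is returned so the caller can keep it open for as
    long as the handler is enabled.
    """
    out = sys.stderr if log_path is None else open(log_path, "w")
    faulthandler.enable(file=out, all_threads=True)
    return out


# Call once, early in the script (e.g. before app.run(main)).
crash_out = enable_crash_traces()
print("faulthandler enabled:", faulthandler.is_enabled())
```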


mavenlin commented 1 year ago

> I am converting the observation to single precision here but that seems to be the line that causes the crash.

Do you mean that if this line is removed, the crash goes away?

ethanluoyc commented 1 year ago

> Do you mean that if this line is removed, the crash goes away?

Initially I thought so, but now the crash looks flaky. But I guess you have found the issue?

mavenlin commented 1 year ago

Yep, #284 should fix it.

ethanluoyc commented 1 year ago

@mavenlin Hmm, I tried it on my side but the issue seems to persist. I will take a closer look at my setup.

mavenlin commented 1 year ago

> Hmm, I tried it on my side but the issue seems to persist. I will take a closer look at my setup.

I tested the wheel from here. I can run your code above without any issue.

ethanluoyc commented 1 year ago

Yeah, it seems to work. I was experimenting with PDM, and that seems to have messed up my pip installation somehow. Many thanks for fixing this! It would be super cool if there were a new release on PyPI.

Trinkle23897 commented 1 year ago

Will do this weekend, sorry for the delay.

Trinkle23897 commented 1 year ago

Done: `pip install envpool` will now install 0.8.4.
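The report was made against envpool 0.8.3, and the release that followed the fix is 0.8.4; the thread also shows how a stale install can make a fix look ineffective. A small sketch for failing fast when an older wheel is still installed; `parse_version` and `require_min_version` are hypothetical helpers, and the parsing assumes plain numeric `X.Y.Z` versions (no pre-release tags). The `packaging` library's `Version` class is the more robust choice where available.

```python
from importlib import metadata


def parse_version(text):
    """Turn '0.8.4' into (0, 8, 4). Assumes purely numeric, dot-separated parts."""
    return tuple(int(part) for part in text.split("."))


def require_min_version(package, minimum):
    """Return the installed version, or raise RuntimeError if missing or too old."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        raise RuntimeError(f"{package} is not installed") from None
    if parse_version(installed) < parse_version(minimum):
        raise RuntimeError(
            f"{package} {installed} is installed but >= {minimum} is required; "
            f"try `pip install -U {package}`"
        )
    return installed
```

Typical use would be `require_min_version("envpool", "0.8.4")` at the top of a training script, before creating any environments.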