araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

[Bug] example supplied in readme crashing #9

Closed: Robokan closed this issue 7 months ago

Robokan commented 1 year ago


### 🐛 Bug


If I run the example in your README, it crashes.

### To Reproduce


```python
import gym

from sbx import TQC, DroQ, SAC, PPO, DQN

env = gym.make("Pendulum-v1")

model = TQC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)

vec_env = model.get_env()
obs = vec_env.reset()
for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = vec_env.step(action)
    vec_env.render()

vec_env.close()
```

```
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [34], in <cell line: 7>()
      3 from sbx import TQC, DroQ, SAC, PPO, DQN
      5 env = gym.make("Pendulum-v1")
----> 7 model = TQC("MlpPolicy", env, verbose=1)
      8 model.learn(total_timesteps=10_000, progress_bar=True)
     10 vec_env = model.get_env()

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/tqc.py:102, in TQC.__init__(self, policy, env, learning_rate, qf_learning_rate, buffer_size, learning_starts, batch_size, tau, gamma, train_freq, gradient_steps, policy_delay, top_quantiles_to_drop_per_net, action_noise, ent_coef, use_sde, sde_sample_freq, use_sde_at_warmup, tensorboard_log, policy_kwargs, verbose, seed, device, _init_setup_model)
     99 self.policy_kwargs["top_quantiles_to_drop_per_net"] = top_quantiles_to_drop_per_net
    101 if _init_setup_model:
--> 102     self._setup_model()

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/tqc.py:115, in TQC._setup_model(self)
    107 if self.policy is None:
    108     self.policy = self.policy_class(  # pytype:disable=not-instantiable
    109         self.observation_space,
    110         self.action_space,
    111         self.lr_schedule,
    112         **self.policy_kwargs,  # pytype:disable=not-instantiable
    113     )
--> 115     self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
    117     self.key, ent_key = jax.random.split(self.key, 2)
    119     self.actor = self.policy.actor

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/policies.py:142, in TQCPolicy.build(self, key, lr_schedule, qf_learning_rate)
    137 # Hack to make gSDE work without modifying internal SB3 code
    138 self.actor.reset_noise = self.reset_noise
    140 self.actor_state = TrainState.create(
    141     apply_fn=self.actor.apply,
--> 142     params=self.actor.init(actor_key, obs),
    143     tx=self.optimizer_class(learning_rate=lr_schedule(1), **self.optimizer_kwargs),
    144 )
    146 self.qf = Critic(
    147     dropout_rate=self.dropout_rate,
    148     use_layer_norm=self.layer_norm,
    149     n_units=self.n_units,
    150     n_quantiles=self.n_quantiles,
    151 )
    153 self.qf1_state = RLTrainState.create(
    154     apply_fn=self.qf.apply,
    155     params=self.qf.init(
   (...)
    165     tx=optax.adam(learning_rate=qf_learning_rate),
    166 )

    [... skipping hidden 9 frame]

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/sbx/tqc/policies.py:66, in Actor.__call__(self, x)
     63 log_std = nn.Dense(self.action_dim)(x)
     64 log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
     65 dist = TanhTransformedDistribution(
---> 66     tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
     67 )
     68 return dist

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py:235, in MultivariateNormalDiag.__init__(self, loc, scale_diag, scale_identity_multiplier, validate_args, allow_nan_stats, experimental_use_kahan_sum, name)
    232 if scale_diag is not None:
    233   diag_cls = (KahanLogDetLinOpDiag if experimental_use_kahan_sum else
    234               tf.linalg.LinearOperatorDiag)
--> 235   scale = diag_cls(
    236       diag=scale_diag,
    237       is_non_singular=True,
    238       is_self_adjoint=True,
    239       is_positive_definite=False)
    240 else:
    241   # Deprecated behavior; breaks variable-safety rules by calling
    242   # `tf.shape(loc)`.
    243   num_rows = tf.compat.dimension_value(loc.shape[-1])

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py:191, in LinearOperatorDiag.__init__(self, diag, is_non_singular, is_self_adjoint, is_positive_definite, is_square, name)
    182 super(LinearOperatorDiag, self).__init__(
    183     dtype=self._diag.dtype,
    184     is_non_singular=is_non_singular,
   (...)
    188     parameters=parameters,
    189     name=name)
    190 # TODO(b/143910018) Remove graph_parents in V3.
--> 191 self._set_graph_parents([self._diag])

File ~/miniconda3/envs/py39/lib/python3.9/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator.py:1177, in LinearOperator._set_graph_parents(self, graph_parents)
   1174 for i, t in enumerate(graph_parents):
   1175   if t is None or not (linear_operator_util.is_ref(t) or
   1176                        ops.is_tensor(t)):
-> 1177     raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
   1178 self._graph_parents = graph_parents

ValueError: Graph parent item 0 is not a Tensor; [[0.48654944]].
```

### Expected behavior

If I import PPO (or the other algorithms) from Stable-Baselines3, they train this example perfectly. I expect SBX to do the same.
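
For comparison, this is roughly the SB3 version that trains without issue on the same machine (a sketch, not the exact script I ran):

```python
import gym

from stable_baselines3 import PPO

env = gym.make("Pendulum-v1")

# Same environment and settings as the SBX example above,
# but using the PyTorch-based SB3 implementation
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
```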

### System Info

 * Operating system: macOS Monterey
 * How the library was installed: pip
 * GPU models and configuration:
 * Python version: 3.9.15
 * PyTorch version: torch 1.13.1
 * Gym version: gym 0.21.0
 * Versions of any other relevant libraries:

You can use `sb3.get_system_info()` to print relevant packages info:
```python
import stable_baselines3 as sb3
sb3.get_system_info()
```

araffin commented 1 year ago

Hello, I guess if you use another OS (for instance Linux in a Google Colab), it does work? Make sure to have the latest versions of tensorflow probability and jax; support for macOS might be experimental.
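
For example, something like this can be used to check what is currently installed before upgrading (a rough sketch; `sbx-rl` is assumed here to be the name of the pip package):

```python
# Quick check of the installed versions
# (then upgrade with e.g. `pip install -U jax jaxlib tensorflow-probability`)
from importlib.metadata import version

for pkg in ("jax", "jaxlib", "tensorflow-probability", "sbx-rl"):
    try:
        print(pkg, version(pkg))
    except Exception:
        print(pkg, "not installed")
```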

Robokan commented 1 year ago

I have the latest JAX working great on macOS. I am even running the Brax simulator on top of it. Your repo would be great for speeding up my stable-baselines/Gym code, and I really appreciate you posting it. Any ideas on how to go about debugging this?

I did try it on Google Colab and your example does work there.

araffin commented 1 year ago

> I did try it on Google Colab and your example does work there.

> I have the latest JAX working great on macOS.

Then it seems the problem comes from tensorflow probability (which runs on JAX only here, despite the name); you should probably open an issue there.
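
Something along these lines should isolate it outside of sbx for the TFP report (an untested sketch, using the same JAX substrate import path that appears in your traceback):

```python
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Roughly the same call that fails inside sbx's Actor.__call__:
# a diagonal Gaussian built from JAX arrays
dist = tfd.MultivariateNormalDiag(
    loc=jnp.zeros((1, 1)),
    scale_diag=jnp.exp(jnp.zeros((1, 1))),
)
print(dist.sample(seed=jax.random.PRNGKey(0)))
```

If that snippet raises the same ValueError on your machine, the bug is in tensorflow probability rather than in sbx.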

Robokan commented 1 year ago

Thanks. I will see if there is anything about this in JAX.