coax-dev / coax

Modular framework for Reinforcement Learning in Python
https://coax.readthedocs.io
MIT License

Refactoring of ClippedDoubleQLearning for DSAC #8

Closed · frederikschubert closed this 2 years ago

frederikschubert commented 2 years ago

This PR refactors ClippedDoubleQLearning to use the functions of DoubleQLearning, as preparation for Distributional SAC (DSAC).

frederikschubert commented 2 years ago

Right now this throws an error due to the different sizes of the parameters and function states of the target q functions, e.g. when running examples/pendulum/dsac.py.

ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
(({'pi_targ': FlatMapping({
  'linear': FlatMapping({'b': 8, 'w': 3}),
  'linear_1': FlatMapping({'b': 8, 'w': 8}),
  'linear_2': FlatMapping({'b': 8, 'w': 8}),
  'linear_3': FlatMapping({'b': 2, 'w': 8}),
}), 'q': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})}), 'q_targ': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})})}, {'pi_targ': FlatMapping({
  'linear': FlatMapping({'b': 8, 'w': 3}),
  'linear_1': FlatMapping({'b': 8, 'w': 8}),
  'linear_2': FlatMapping({'b': 8, 'w': 8}),
  'linear_3': FlatMapping({'b': 2, 'w': 8}),
}), 'q': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})}), 'q_targ': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})})}), ({'pi_targ': FlatMapping({}), 'q': FlatMapping({}), 'q_targ': FlatMapping({})}, {'pi_targ': FlatMapping({}), 'q': FlatMapping({}), 'q_targ': FlatMapping({})}))
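
A minimal sketch of this failure mode (my own construction with illustrative shapes, not the actual networks): jax.vmap requires every mapped leaf to have the same size along its mapped axis, so mapping in_axes=0 over a parameter pytree whose leaves have unequal leading dimensions fails with exactly this message.

import jax
import jax.numpy as jnp

# Illustrative repro: the two leaves have leading axes of size 2 and 3, so
# vmap (with the default in_axes=0) cannot find a common axis size to map.
params = {'pi_targ': jnp.zeros((2, 8)), 'q_targ': jnp.zeros((3, 51))}

def f(p):
    return p['pi_targ'].sum() + p['q_targ'].sum()

try:
    jax.vmap(f)(params)
except ValueError as err:
    print(err)  # vmap got inconsistent sizes for array axes to be mapped: ...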
KristianHolsheimer commented 2 years ago

> Right now this throws an error due to the different sizes of the parameters and function states of the target q functions, e.g. when running examples/pendulum/dsac.py.
>
> ValueError: vmap got inconsistent sizes for array axes to be mapped: [...]

I think you'll have to vmap twice, once for q_targ_list and once for pi_targ_list.

frederikschubert commented 2 years ago

I found a workaround using this: https://github.com/google/jax/issues/3102. However, I now have to select the correct target parameters using q_targ_idx, which is a Tracer object created by vmap. I also had to introduce a greedy_pi_targ parameter on the BaseTDLearningQWithTargetPolicy class, because that is what SoftClippedDoubleQLearning now uses.
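
For reference, the core of that workaround is to index the stacked target parameters leaf-wise with the traced index. A minimal sketch (my own, not the PR's code; jax.tree_util.tree_map is the current name for tree_multimap, and the shapes are made up):

import jax
import jax.numpy as jnp

# Illustrative: params of all target q-functions stacked along a new leading
# axis; one set is selected with q_targ_idx. x[idx] lowers to a dynamic
# slice, so it also works when idx is a Tracer created by vmap or jit.
stacked = {'w': jnp.arange(6.0).reshape(3, 2), 'b': jnp.arange(3.0)}

def select(params, q_targ_idx):
    return jax.tree_util.tree_map(lambda x: x[q_targ_idx], params)

print(jax.jit(select)(stacked, 1))  # {'w': [2., 3.], 'b': 1.0}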

KristianHolsheimer commented 2 years ago

Here's a toy example of what I had in mind:

import jax
import numpy as np

def f(params, x):
    y = params['q_targ'] * x
    z = params['p_targ']
    return y * z

x = 11
params = {'q_targ': 13, 'p_targ': 7}
f(params, x)  # shape: ()

# vectorize over q_targ_list
vf = jax.vmap(f, ({'q_targ': 0, 'p_targ': None}, None), 0, 'q_targ')
vparams = {'q_targ': np.array([13, 17, 19]), 'p_targ': 5}
vf(vparams, x)  # shape: (len(q_targ_list),)

# vectorize over pi_targ_list
vvf = jax.vmap(vf, ({'q_targ': None, 'p_targ': 0}, None), 0, 'p_targ')
vvparams = {'q_targ': np.array([13, 17, 19]), 'p_targ': np.array([5, 7])}
vvf(vvparams, x)  # shape: (len(pi_targ_list), len(q_targ_list))

As for the state and params, we'll need to stack them separately, e.g.

@property
def target_params(self):
    return hk.data_structures.to_immutable_dict({
        'q': self.q.params,
        'q_targ': jax.tree_multimap(
            lambda *xs: jnp.stack(xs, axis=0), *(q.params for q in self.q_targ_list)),
        'pi_targ': jax.tree_multimap(
            lambda *xs: jnp.stack(xs, axis=0), *(pi.params for pi in self.pi_targ_list)),
    })

@property
def target_state(self):
    ...  # same idea

The one thing that isn't great about this, performance-wise, is that the jnp.stack happens outside of a jax.jit call.
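
One way around that (a sketch under assumptions, not coax's actual API; td_targets and the shapes are hypothetical) would be to do the leaf-wise stacking inside the jitted update itself, so it is traced once and can be optimized together with the rest of the computation rather than running eagerly on every access to target_params:

import jax
import jax.numpy as jnp

# Hypothetical: stack the target params leaf-wise inside the traced function,
# then vmap over the new leading axis.
@jax.jit
def td_targets(q_targ_params, x):
    stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *q_targ_params)
    return jax.vmap(lambda p, x: p['w'] @ x, in_axes=(0, None))(stacked, x)

q_targ_params = [{'w': jnp.ones((2, 3))}, {'w': 2 * jnp.ones((2, 3))}]
x = jnp.arange(3.0)
print(td_targets(q_targ_params, x).shape)  # (2, 2)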

frederikschubert commented 2 years ago

Thank you for taking the time to show me this example!

> Here's a toy example of what I had in mind:

Yes, that looks much more reasonable than my hacky approach. I got it to work now, but it obviously still needs refactoring. The general structure is visible now, the tests should pass, and the dsac example is learning.

It even seems that the code is now a bit faster than before: for SAC (with its simpler network) a step used to take around 11 ms and is now at around 8 ms. I wonder why that is...

frederikschubert commented 2 years ago
(coax) ~/projects/coax on ilmarinen.tnt.uni-hannover.de
❯ xvfb-run python /home/schubert/projects/coax/doc/examples/pendulum/dsac.py
[dsac|absl|INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
[dsac|absl|INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
[dsac|TrainMonitor|INFO] ep: 1, T: 201, G: -1.53e+03,   avg_r: -7.65,   avg_G: -1.53e+03,       t: 200, dt: 7.920ms
[dsac|TrainMonitor|INFO] ep: 2, T: 402, G: -979,        avg_r: -4.9,    avg_G: -1.25e+03,       t: 200, dt: 4.962ms
[dsac|TrainMonitor|INFO] ep: 3, T: 603, G: -758,        avg_r: -3.79,   avg_G: -1.09e+03,       t: 200, dt: 5.308ms
[dsac|TrainMonitor|INFO] ep: 4, T: 804, G: -1.66e+03,   avg_r: -8.31,   avg_G: -1.23e+03,       t: 200, dt: 4.933ms
[dsac|TrainMonitor|INFO] ep: 5, T: 1,005,       G: -1.41e+03,   avg_r: -7.07,   avg_G: -1.27e+03,       t: 200, dt: 4.955ms
[dsac|TrainMonitor|INFO] ep: 6, T: 1,206,       G: -888,        avg_r: -4.44,   avg_G: -1.2e+03,        t: 200, dt: 4.949ms
[dsac|TrainMonitor|INFO] ep: 7, T: 1,407,       G: -699,        avg_r: -3.49,   avg_G: -1.13e+03,       t: 200, dt: 4.903ms
[dsac|TrainMonitor|INFO] ep: 8, T: 1,608,       G: -895,        avg_r: -4.47,   avg_G: -1.1e+03,        t: 200, dt: 4.869ms
[dsac|TrainMonitor|INFO] ep: 9, T: 1,809,       G: -1.69e+03,   avg_r: -8.43,   avg_G: -1.17e+03,       t: 200, dt: 4.966ms
[dsac|TrainMonitor|INFO] ep: 10,        T: 2,010,       G: -1.8e+03,    avg_r: -8.98,   avg_G: -1.23e+03,       t: 200, dt: 4.993ms
[dsac|TrainMonitor|INFO] ep: 11,        T: 2,211,       G: -1.47e+03,   avg_r: -7.34,   avg_G: -1.25e+03,       t: 200, dt: 5.066ms
[dsac|TrainMonitor|INFO] ep: 12,        T: 2,412,       G: -1.48e+03,   avg_r: -7.4,    avg_G: -1.28e+03,       t: 200, dt: 4.989ms
[dsac|TrainMonitor|INFO] ep: 13,        T: 2,613,       G: -1.19e+03,   avg_r: -5.94,   avg_G: -1.27e+03,       t: 200, dt: 4.994ms
[dsac|TrainMonitor|INFO] ep: 14,        T: 2,814,       G: -1.36e+03,   avg_r: -6.8,    avg_G: -1.28e+03,       t: 200, dt: 5.084ms
[dsac|TrainMonitor|INFO] ep: 15,        T: 3,015,       G: -1.38e+03,   avg_r: -6.9,    avg_G: -1.29e+03,       t: 200, dt: 4.978ms
[dsac|TrainMonitor|INFO] ep: 16,        T: 3,216,       G: -971,        avg_r: -4.86,   avg_G: -1.26e+03,       t: 200, dt: 4.898ms
[dsac|TrainMonitor|INFO] ep: 17,        T: 3,417,       G: -1.06e+03,   avg_r: -5.31,   avg_G: -1.24e+03,       t: 200, dt: 4.955ms
[dsac|TrainMonitor|INFO] ep: 18,        T: 3,618,       G: -759,        avg_r: -3.79,   avg_G: -1.19e+03,       t: 200, dt: 5.112ms
[dsac|TrainMonitor|INFO] ep: 19,        T: 3,819,       G: -999,        avg_r: -4.99,   avg_G: -1.17e+03,       t: 200, dt: 5.148ms
[dsac|TrainMonitor|INFO] ep: 20,        T: 4,020,       G: -1.66e+03,   avg_r: -8.3,    avg_G: -1.22e+03,       t: 200, dt: 4.893ms
[dsac|TrainMonitor|INFO] ep: 21,        T: 4,221,       G: -1.08e+03,   avg_r: -5.41,   avg_G: -1.21e+03,       t: 200, dt: 4.952ms
[dsac|TrainMonitor|INFO] ep: 22,        T: 4,422,       G: -1.46e+03,   avg_r: -7.3,    avg_G: -1.23e+03,       t: 200, dt: 4.868ms
[dsac|TrainMonitor|INFO] ep: 23,        T: 4,623,       G: -990,        avg_r: -4.95,   avg_G: -1.21e+03,       t: 200, dt: 5.005ms
[dsac|TrainMonitor|INFO] ep: 24,        T: 4,824,       G: -1.78e+03,   avg_r: -8.92,   avg_G: -1.26e+03,       t: 200, dt: 5.118ms
[dsac|TrainMonitor|INFO] ep: 25,        T: 5,025,       G: -916,        avg_r: -4.58,   avg_G: -1.23e+03,       t: 200, dt: 42.785ms,   SoftClippedDoubleQLearning/grads_max: 2.43,     SoftClippedDoubleQLearning/grads_norm: 3.21,    SoftClippedDoubleQLearning/loss: 3.93,      SoftClippedDoubleQLearning/td_error: -4.25,     SoftClippedDoubleQLearning/td_error_targ: 0
[dsac|TrainMonitor|INFO] ep: 26,        T: 5,226,       G: -1.21e+03,   avg_r: -6.03,   avg_G: -1.23e+03,       t: 200, dt: 47.650ms,   SoftClippedDoubleQLearning/grads_max: 1.8,      SoftClippedDoubleQLearning/grads_norm: 2.51,    SoftClippedDoubleQLearning/loss: 3.42,      SoftClippedDoubleQLearning/td_error: -4.03,     SoftClippedDoubleQLearning/td_error_targ: -0.171
[dsac|TrainMonitor|INFO] ep: 27,        T: 5,427,       G: -889,        avg_r: -4.44,   avg_G: -1.19e+03,       t: 200, dt: 15.969ms,   SoftClippedDoubleQLearning/grads_max: 1.37,     SoftClippedDoubleQLearning/grads_norm: 2.15,    SoftClippedDoubleQLearning/loss: 2.49,      SoftClippedDoubleQLearning/td_error: -3.31,     SoftClippedDoubleQLearning/td_error_targ: -0.899
[dsac|TrainMonitor|INFO] ep: 28,        T: 5,628,       G: -1.26e+03,   avg_r: -6.29,   avg_G: -1.2e+03,        t: 200, dt: 16.326ms,   SoftClippedDoubleQLearning/grads_max: 1.04,     SoftClippedDoubleQLearning/grads_norm: 1.86,    SoftClippedDoubleQLearning/loss: 1.8,       SoftClippedDoubleQLearning/td_error: -2.46,     SoftClippedDoubleQLearning/td_error_targ: -1.71
[dsac|TrainMonitor|INFO] ep: 29,        T: 5,829,       G: -1.07e+03,   avg_r: -5.34,   avg_G: -1.19e+03,       t: 200, dt: 16.240ms,   SoftClippedDoubleQLearning/grads_max: 0.79,     SoftClippedDoubleQLearning/grads_norm: 1.58,    SoftClippedDoubleQLearning/loss: 1.34,      SoftClippedDoubleQLearning/td_error: -1.62,     SoftClippedDoubleQLearning/td_error_targ: -2.34
[dsac|TrainMonitor|INFO] ep: 30,        T: 6,030,       G: -936,        avg_r: -4.68,   avg_G: -1.16e+03,       t: 200, dt: 16.112ms,   SoftClippedDoubleQLearning/grads_max: 0.628,    SoftClippedDoubleQLearning/grads_norm: 1.36,    SoftClippedDoubleQLearning/loss: 1.08,      SoftClippedDoubleQLearning/td_error: -0.987,    SoftClippedDoubleQLearning/td_error_targ: -2.65
[dsac|TrainMonitor|INFO] ep: 31,        T: 6,231,       G: -1.71e+03,   avg_r: -8.56,   avg_G: -1.22e+03,       t: 200, dt: 15.815ms,   SoftClippedDoubleQLearning/grads_max: 0.535,    SoftClippedDoubleQLearning/grads_norm: 1.22,    SoftClippedDoubleQLearning/loss: 0.941,     SoftClippedDoubleQLearning/td_error: -0.611,    SoftClippedDoubleQLearning/td_error_targ: -2.67
[dsac|TrainMonitor|INFO] ep: 32,        T: 6,432,       G: -1.64e+03,   avg_r: -8.18,   avg_G: -1.26e+03,       t: 200, dt: 16.300ms,   SoftClippedDoubleQLearning/grads_max: 0.478,    SoftClippedDoubleQLearning/grads_norm: 1.11,    SoftClippedDoubleQLearning/loss: 0.863,     SoftClippedDoubleQLearning/td_error: -0.367,    SoftClippedDoubleQLearning/td_error_targ: -2.49
[dsac|TrainMonitor|INFO] ep: 33,        T: 6,633,       G: -1.23e+03,   avg_r: -6.16,   avg_G: -1.26e+03,       t: 200, dt: 16.151ms,   SoftClippedDoubleQLearning/grads_max: 0.444,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.814,     SoftClippedDoubleQLearning/td_error: -0.237,    SoftClippedDoubleQLearning/td_error_targ: -2.2
[dsac|TrainMonitor|INFO] ep: 34,        T: 6,834,       G: -879,        avg_r: -4.4,    avg_G: -1.22e+03,       t: 200, dt: 16.293ms,   SoftClippedDoubleQLearning/grads_max: 0.423,    SoftClippedDoubleQLearning/grads_norm: 1.01,    SoftClippedDoubleQLearning/loss: 0.793,     SoftClippedDoubleQLearning/td_error: -0.164,    SoftClippedDoubleQLearning/td_error_targ: -1.88
[dsac|TrainMonitor|INFO] ep: 35,        T: 7,035,       G: -768,        avg_r: -3.84,   avg_G: -1.17e+03,       t: 200, dt: 16.244ms,   SoftClippedDoubleQLearning/grads_max: 0.438,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.826,     SoftClippedDoubleQLearning/td_error: -0.0973,   SoftClippedDoubleQLearning/td_error_targ: -1.51
[dsac|TrainMonitor|INFO] ep: 36,        T: 7,236,       G: -1.43e+03,   avg_r: -7.17,   avg_G: -1.2e+03,        t: 200, dt: 15.914ms,   SoftClippedDoubleQLearning/grads_max: 0.457,    SoftClippedDoubleQLearning/grads_norm: 1.09,    SoftClippedDoubleQLearning/loss: 0.84,      SoftClippedDoubleQLearning/td_error: -0.0822,   SoftClippedDoubleQLearning/td_error_targ: -1.21
[dsac|TrainMonitor|INFO] ep: 37,        T: 7,437,       G: -1.15e+03,   avg_r: -5.75,   avg_G: -1.19e+03,       t: 200, dt: 15.901ms,   SoftClippedDoubleQLearning/grads_max: 0.435,    SoftClippedDoubleQLearning/grads_norm: 1.04,    SoftClippedDoubleQLearning/loss: 0.796,     SoftClippedDoubleQLearning/td_error: -0.0736,   SoftClippedDoubleQLearning/td_error_targ: -0.953
[dsac|TrainMonitor|INFO] ep: 38,        T: 7,638,       G: -879,        avg_r: -4.39,   avg_G: -1.16e+03,       t: 200, dt: 30.564ms,   SoftClippedDoubleQLearning/grads_max: 0.428,    SoftClippedDoubleQLearning/grads_norm: 1.04,    SoftClippedDoubleQLearning/loss: 0.788,     SoftClippedDoubleQLearning/td_error: -0.0607,   SoftClippedDoubleQLearning/td_error_targ: -0.759,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.15,       SoftPG/grads_max: 0.27, SoftPG/grads_norm: 0.432,   SoftPG/kl_div_old: 1.4, SoftPG/loss: 8.54,      SoftPG/loss_bare: 8.77
[dsac|TrainMonitor|INFO] ep: 39,        T: 7,839,       G: -1.76e+03,   avg_r: -8.8,    avg_G: -1.22e+03,       t: 200, dt: 17.507ms,   SoftClippedDoubleQLearning/grads_max: 0.432,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.798,     SoftClippedDoubleQLearning/td_error: -0.0452,   SoftClippedDoubleQLearning/td_error_targ: -0.593,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.02,       SoftPG/grads_max: 0.407,        SoftPG/grads_norm: 0.783,   SoftPG/kl_div_old: 1.25,        SoftPG/loss: 8.65,      SoftPG/loss_bare: 8.86
[dsac|TrainMonitor|INFO] ep: 40,        T: 8,040,       G: -1.65e+03,   avg_r: -8.25,   avg_G: -1.27e+03,       t: 200, dt: 17.767ms,   SoftClippedDoubleQLearning/grads_max: 0.43,     SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.789,     SoftClippedDoubleQLearning/td_error: -0.0393,   SoftClippedDoubleQLearning/td_error_targ: -0.472,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.548,      SoftPG/grads_max: 0.629,        SoftPG/grads_norm: 1.19,    SoftPG/kl_div_old: 0.677,       SoftPG/loss: 8.62,      SoftPG/loss_bare: 8.73
[dsac|TrainMonitor|INFO] ep: 41,        T: 8,241,       G: -1.64e+03,   avg_r: -8.22,   avg_G: -1.3e+03,        t: 200, dt: 17.708ms,   SoftClippedDoubleQLearning/grads_max: 0.413,    SoftClippedDoubleQLearning/grads_norm: 1.02,    SoftClippedDoubleQLearning/loss: 0.77,      SoftClippedDoubleQLearning/td_error: -0.0282,   SoftClippedDoubleQLearning/td_error_targ: -0.388,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.154,      SoftPG/grads_max: 0.651,        SoftPG/grads_norm: 1.14,    SoftPG/kl_div_old: 0.186,       SoftPG/loss: 8.7,       SoftPG/loss_bare: 8.73
[dsac|TrainMonitor|INFO] ep: 42,        T: 8,442,       G: -799,        avg_r: -4,      avg_G: -1.25e+03,       t: 200, dt: 18.241ms,   SoftClippedDoubleQLearning/grads_max: 0.401,    SoftClippedDoubleQLearning/grads_norm: 0.999,   SoftClippedDoubleQLearning/loss: 0.769,     SoftClippedDoubleQLearning/td_error: -0.0266,   SoftClippedDoubleQLearning/td_error_targ: -0.311,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.129,      SoftPG/grads_max: 0.449,        SoftPG/grads_norm: 0.772,   SoftPG/kl_div_old: 0.147,       SoftPG/loss: 8.83,      SoftPG/loss_bare: 8.85
[dsac|TrainMonitor|INFO] ep: 43,        T: 8,643,       G: -1.17e+03,   avg_r: -5.84,   avg_G: -1.24e+03,       t: 200, dt: 17.834ms,   SoftClippedDoubleQLearning/grads_max: 0.402,    SoftClippedDoubleQLearning/grads_norm: 1.01,    SoftClippedDoubleQLearning/loss: 0.77,      SoftClippedDoubleQLearning/td_error: -0.0296,   SoftClippedDoubleQLearning/td_error_targ: -0.251,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.195,      SoftPG/grads_max: 0.487,        SoftPG/grads_norm: 0.84,    SoftPG/kl_div_old: 0.24,        SoftPG/loss: 8.95,      SoftPG/loss_bare: 8.99
[dsac|TrainMonitor|INFO] ep: 44,        T: 8,844,       G: -1.37e+03,   avg_r: -6.87,   avg_G: -1.26e+03,       t: 200, dt: 18.037ms,   SoftClippedDoubleQLearning/grads_max: 0.393,    SoftClippedDoubleQLearning/grads_norm: 0.994,   SoftClippedDoubleQLearning/loss: 0.759,     SoftClippedDoubleQLearning/td_error: -0.0255,   SoftClippedDoubleQLearning/td_error_targ: -0.208,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.243,      SoftPG/grads_max: 0.468,        SoftPG/grads_norm: 0.84,    SoftPG/kl_div_old: 0.309,       SoftPG/loss: 9.04,      SoftPG/loss_bare: 9.09
[dsac|TrainMonitor|INFO] ep: 45,        T: 9,045,       G: -857,        avg_r: -4.29,   avg_G: -1.22e+03,       t: 200, dt: 18.677ms,   SoftClippedDoubleQLearning/grads_max: 0.395,    SoftClippedDoubleQLearning/grads_norm: 1,       SoftClippedDoubleQLearning/loss: 0.761,     SoftClippedDoubleQLearning/td_error: -0.0199,   SoftClippedDoubleQLearning/td_error_targ: -0.173,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.306,      SoftPG/grads_max: 0.362,        SoftPG/grads_norm: 0.63,    SoftPG/kl_div_old: 0.386,       SoftPG/loss: 9.11,      SoftPG/loss_bare: 9.17
[dsac|TrainMonitor|INFO] ep: 46,        T: 9,246,       G: -1.66e+03,   avg_r: -8.29,   avg_G: -1.26e+03,       t: 200, dt: 18.522ms,   SoftClippedDoubleQLearning/grads_max: 0.401,    SoftClippedDoubleQLearning/grads_norm: 1.02,    SoftClippedDoubleQLearning/loss: 0.778,     SoftClippedDoubleQLearning/td_error: -0.0136,   SoftClippedDoubleQLearning/td_error_targ: -0.14,        EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.342,      SoftPG/grads_max: 0.293,        SoftPG/grads_norm: 0.491,   SoftPG/kl_div_old: 0.439,       SoftPG/loss: 9.17,      SoftPG/loss_bare: 9.24
[dsac|TrainMonitor|INFO] ep: 47,        T: 9,447,       G: -1.69e+03,   avg_r: -8.45,   avg_G: -1.3e+03,        t: 200, dt: 18.874ms,   SoftClippedDoubleQLearning/grads_max: 0.392,    SoftClippedDoubleQLearning/grads_norm: 1,       SoftClippedDoubleQLearning/loss: 0.757,     SoftClippedDoubleQLearning/td_error: -0.0188,   SoftClippedDoubleQLearning/td_error_targ: -0.116,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.421,      SoftPG/grads_max: 0.297,        SoftPG/grads_norm: 0.51,    SoftPG/kl_div_old: 0.525,       SoftPG/loss: 9.21,      SoftPG/loss_bare: 9.29
[dsac|TrainMonitor|INFO] ep: 48,        T: 9,648,       G: -1.49e+03,   avg_r: -7.46,   avg_G: -1.32e+03,       t: 200, dt: 18.156ms,   SoftClippedDoubleQLearning/grads_max: 0.392,    SoftClippedDoubleQLearning/grads_norm: 1.01,    SoftClippedDoubleQLearning/loss: 0.757,     SoftClippedDoubleQLearning/td_error: -0.0144,   SoftClippedDoubleQLearning/td_error_targ: -0.0985,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.48,       SoftPG/grads_max: 0.278,        SoftPG/grads_norm: 0.5,     SoftPG/kl_div_old: 0.587,       SoftPG/loss: 9.24,      SoftPG/loss_bare: 9.33
[dsac|TrainMonitor|INFO] ep: 49,        T: 9,849,       G: -1.81e+03,   avg_r: -9.06,   avg_G: -1.37e+03,       t: 200, dt: 17.650ms,   SoftClippedDoubleQLearning/grads_max: 0.373,    SoftClippedDoubleQLearning/grads_norm: 0.965,   SoftClippedDoubleQLearning/loss: 0.719,     SoftClippedDoubleQLearning/td_error: -0.0183,   SoftClippedDoubleQLearning/td_error_targ: -0.092,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.555,      SoftPG/grads_max: 0.255,        SoftPG/grads_norm: 0.436,   SoftPG/kl_div_old: 0.687,       SoftPG/loss: 9.27,      SoftPG/loss_bare: 9.38
[dsac|generate_gif|INFO] recorded episode to: ./data/gifs/dsac/T00010000.gif
[dsac|TrainMonitor|INFO] ep: 50,        T: 10,050,      G: -625,        avg_r: -3.12,   avg_G: -1.3e+03,        t: 200, dt: 45.271ms,   SoftClippedDoubleQLearning/grads_max: 0.378,    SoftClippedDoubleQLearning/grads_norm: 0.985,   SoftClippedDoubleQLearning/loss: 0.749,     SoftClippedDoubleQLearning/td_error: 0.000823,  SoftClippedDoubleQLearning/td_error_targ: -0.0778,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.593,      SoftPG/grads_max: 0.23, SoftPG/grads_norm: 0.397,   SoftPG/kl_div_old: 0.717,       SoftPG/loss: 9.29,      SoftPG/loss_bare: 9.41
[dsac|TrainMonitor|INFO] ep: 51,        T: 10,251,      G: -1.39e+03,   avg_r: -6.96,   avg_G: -1.31e+03,       t: 200, dt: 17.901ms,   SoftClippedDoubleQLearning/grads_max: 0.386,    SoftClippedDoubleQLearning/grads_norm: 1.01,    SoftClippedDoubleQLearning/loss: 0.763,     SoftClippedDoubleQLearning/td_error: -0.00187,  SoftClippedDoubleQLearning/td_error_targ: -0.0488,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.646,      SoftPG/grads_max: 0.201,        SoftPG/grads_norm: 0.352,   SoftPG/kl_div_old: 0.796,       SoftPG/loss: 9.31,      SoftPG/loss_bare: 9.44
[dsac|TrainMonitor|INFO] ep: 52,        T: 10,452,      G: -1.43e+03,   avg_r: -7.16,   avg_G: -1.32e+03,       t: 200, dt: 17.584ms,   SoftClippedDoubleQLearning/grads_max: 0.374,    SoftClippedDoubleQLearning/grads_norm: 0.981,   SoftClippedDoubleQLearning/loss: 0.738,     SoftClippedDoubleQLearning/td_error: -0.00105,  SoftClippedDoubleQLearning/td_error_targ: -0.0499,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.701,      SoftPG/grads_max: 0.18, SoftPG/grads_norm: 0.3,     SoftPG/kl_div_old: 0.857,       SoftPG/loss: 9.32,      SoftPG/loss_bare: 9.46
[dsac|TrainMonitor|INFO] ep: 53,        T: 10,653,      G: -860,        avg_r: -4.3,    avg_G: -1.27e+03,       t: 200, dt: 17.879ms,   SoftClippedDoubleQLearning/grads_max: 0.376,    SoftClippedDoubleQLearning/grads_norm: 0.99,    SoftClippedDoubleQLearning/loss: 0.752,     SoftClippedDoubleQLearning/td_error: -0.000538, SoftClippedDoubleQLearning/td_error_targ: -0.0384,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.742,      SoftPG/grads_max: 0.144,        SoftPG/grads_norm: 0.254,   SoftPG/kl_div_old: 0.913,       SoftPG/loss: 9.33,      SoftPG/loss_bare: 9.48
[dsac|TrainMonitor|INFO] ep: 54,        T: 10,854,      G: -759,        avg_r: -3.79,   avg_G: -1.22e+03,       t: 200, dt: 17.867ms,   SoftClippedDoubleQLearning/grads_max: 0.384,    SoftClippedDoubleQLearning/grads_norm: 1.01,    SoftClippedDoubleQLearning/loss: 0.764,     SoftClippedDoubleQLearning/td_error: -0.00205,  SoftClippedDoubleQLearning/td_error_targ: -0.0196,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.771,      SoftPG/grads_max: 0.159,        SoftPG/grads_norm: 0.285,   SoftPG/kl_div_old: 0.931,       SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.49
[dsac|TrainMonitor|INFO] ep: 55,        T: 11,055,      G: -1.19e+03,   avg_r: -5.94,   avg_G: -1.22e+03,       t: 200, dt: 17.724ms,   SoftClippedDoubleQLearning/grads_max: 0.383,    SoftClippedDoubleQLearning/grads_norm: 1.01,    SoftClippedDoubleQLearning/loss: 0.77,      SoftClippedDoubleQLearning/td_error: 0.00872,   SoftClippedDoubleQLearning/td_error_targ: -0.0199,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.796,      SoftPG/grads_max: 0.122,        SoftPG/grads_norm: 0.194,   SoftPG/kl_div_old: 0.958,       SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.5
[dsac|TrainMonitor|INFO] ep: 56,        T: 11,256,      G: -1.61e+03,   avg_r: -8.04,   avg_G: -1.26e+03,       t: 200, dt: 17.715ms,   SoftClippedDoubleQLearning/grads_max: 0.391,    SoftClippedDoubleQLearning/grads_norm: 1.02,    SoftClippedDoubleQLearning/loss: 0.774,     SoftClippedDoubleQLearning/td_error: -0.000953, SoftClippedDoubleQLearning/td_error_targ: -0.0106,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.812,      SoftPG/grads_max: 0.138,        SoftPG/grads_norm: 0.243,   SoftPG/kl_div_old: 0.992,       SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.51
[dsac|TrainMonitor|INFO] ep: 57,        T: 11,457,      G: -747,        avg_r: -3.74,   avg_G: -1.21e+03,       t: 200, dt: 17.830ms,   SoftClippedDoubleQLearning/grads_max: 0.38,     SoftClippedDoubleQLearning/grads_norm: 1,       SoftClippedDoubleQLearning/loss: 0.755,     SoftClippedDoubleQLearning/td_error: -0.00426,  SoftClippedDoubleQLearning/td_error_targ: -0.02,        EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.832,      SoftPG/grads_max: 0.148,        SoftPG/grads_norm: 0.258,   SoftPG/kl_div_old: 1.01,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 58,        T: 11,658,      G: -765,        avg_r: -3.83,   avg_G: -1.16e+03,       t: 200, dt: 17.787ms,   SoftClippedDoubleQLearning/grads_max: 0.401,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.808,     SoftClippedDoubleQLearning/td_error: 0.0121,    SoftClippedDoubleQLearning/td_error_targ: 0.00259,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.862,      SoftPG/grads_max: 0.126,        SoftPG/grads_norm: 0.219,   SoftPG/kl_div_old: 1.05,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 59,        T: 11,859,      G: -1.27e+03,   avg_r: -6.35,   avg_G: -1.17e+03,       t: 200, dt: 17.920ms,   SoftClippedDoubleQLearning/grads_max: 0.402,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.789,     SoftClippedDoubleQLearning/td_error: -0.018,    SoftClippedDoubleQLearning/td_error_targ: 0.00638,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.874,      SoftPG/grads_max: 0.114,        SoftPG/grads_norm: 0.202,   SoftPG/kl_div_old: 1.06,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 60,        T: 12,060,      G: -1.01e+03,   avg_r: -5.05,   avg_G: -1.16e+03,       t: 200, dt: 17.924ms,   SoftClippedDoubleQLearning/grads_max: 0.4,      SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.798,    SoftClippedDoubleQLearning/td_error: 0.00773,   SoftClippedDoubleQLearning/td_error_targ: 9.54e-05,     EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.87,       SoftPG/grads_max: 0.125,        SoftPG/grads_norm: 0.216,  SoftPG/kl_div_old: 1.07,        SoftPG/loss: 9.36,      SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 61,        T: 12,261,      G: -1.55e+03,   avg_r: -7.74,   avg_G: -1.2e+03,        t: 200, dt: 18.050ms,   SoftClippedDoubleQLearning/grads_max: 0.395,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.788, SoftClippedDoubleQLearning/td_error: 0.00251,   SoftClippedDoubleQLearning/td_error_targ: -2.04e-05,    EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.871, SoftPG/grads_max: 0.133, SoftPG/grads_norm: 0.223,       SoftPG/kl_div_old: 1.07,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 62,        T: 12,462,      G: -856,        avg_r: -4.28,   avg_G: -1.16e+03,       t: 200, dt: 17.947ms,   SoftClippedDoubleQLearning/grads_max: 0.393,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.794, SoftClippedDoubleQLearning/td_error: 0.000313,  SoftClippedDoubleQLearning/td_error_targ: -0.00501,     EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.864, SoftPG/grads_max: 0.118, SoftPG/grads_norm: 0.204,       SoftPG/kl_div_old: 1.06,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 63,        T: 12,663,      G: -969,        avg_r: -4.84,   avg_G: -1.14e+03,       t: 200, dt: 17.778ms,   SoftClippedDoubleQLearning/grads_max: 0.401,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.806, SoftClippedDoubleQLearning/td_error: -0.0033,   SoftClippedDoubleQLearning/td_error_targ: 0.00958,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.872, SoftPG/grads_max: 0.122, SoftPG/grads_norm: 0.216,       SoftPG/kl_div_old: 1.08,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 64,        T: 12,864,      G: -1.24e+03,   avg_r: -6.2,    avg_G: -1.15e+03,       t: 200, dt: 17.919ms,   SoftClippedDoubleQLearning/grads_max: 0.389,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.795, SoftClippedDoubleQLearning/td_error: 0.00509,   SoftClippedDoubleQLearning/td_error_targ: -0.00353,     EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.861, SoftPG/grads_max: 0.142, SoftPG/grads_norm: 0.238,       SoftPG/kl_div_old: 1.05,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 65,        T: 13,065,      G: -879,        avg_r: -4.39,   avg_G: -1.12e+03,       t: 200, dt: 17.753ms,   SoftClippedDoubleQLearning/grads_max: 0.399,    SoftClippedDoubleQLearning/grads_norm: 1.04,    SoftClippedDoubleQLearning/loss: 0.807, SoftClippedDoubleQLearning/td_error: -0.000263, SoftClippedDoubleQLearning/td_error_targ: 0.0021,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.88,  SoftPG/grads_max: 0.12,  SoftPG/grads_norm: 0.206,       SoftPG/kl_div_old: 1.08,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 66,        T: 13,266,      G: -1.87e+03,   avg_r: -9.35,   avg_G: -1.2e+03,        t: 200, dt: 18.146ms,   SoftClippedDoubleQLearning/grads_max: 0.412,    SoftClippedDoubleQLearning/grads_norm: 1.07,    SoftClippedDoubleQLearning/loss: 0.811, SoftClippedDoubleQLearning/td_error: -0.0122,   SoftClippedDoubleQLearning/td_error_targ: 0.00878,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.874, SoftPG/grads_max: 0.122, SoftPG/grads_norm: 0.223,       SoftPG/kl_div_old: 1.07,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 67,        T: 13,467,      G: -968,        avg_r: -4.84,   avg_G: -1.18e+03,       t: 200, dt: 17.884ms,   SoftClippedDoubleQLearning/grads_max: 0.404,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.808, SoftClippedDoubleQLearning/td_error: 0.00291,   SoftClippedDoubleQLearning/td_error_targ: -0.00665,     EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.872, SoftPG/grads_max: 0.129, SoftPG/grads_norm: 0.225,       SoftPG/kl_div_old: 1.06,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 68,        T: 13,668,      G: -626,        avg_r: -3.13,   avg_G: -1.12e+03,       t: 200, dt: 17.906ms,   SoftClippedDoubleQLearning/grads_max: 0.405,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.807, SoftClippedDoubleQLearning/td_error: -0.000917, SoftClippedDoubleQLearning/td_error_targ: 0.00247,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.881, SoftPG/grads_max: 0.149, SoftPG/grads_norm: 0.272,       SoftPG/kl_div_old: 1.08,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 69,        T: 13,869,      G: -1.62e+03,   avg_r: -8.1,    avg_G: -1.17e+03,       t: 200, dt: 17.661ms,   SoftClippedDoubleQLearning/grads_max: 0.423,    SoftClippedDoubleQLearning/grads_norm: 1.1,     SoftClippedDoubleQLearning/loss: 0.841, SoftClippedDoubleQLearning/td_error: 0.00232,   SoftClippedDoubleQLearning/td_error_targ: 0.0128,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.885, SoftPG/grads_max: 0.155, SoftPG/grads_norm: 0.282,       SoftPG/kl_div_old: 1.07,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 70,        T: 14,070,      G: -1.33e+03,   avg_r: -6.63,   avg_G: -1.19e+03,       t: 200, dt: 17.895ms,   SoftClippedDoubleQLearning/grads_max: 0.412,    SoftClippedDoubleQLearning/grads_norm: 1.07,    SoftClippedDoubleQLearning/loss: 0.816, SoftClippedDoubleQLearning/td_error: -0.00507,  SoftClippedDoubleQLearning/td_error_targ: 0.00611,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.885,      SoftPG/grads_max: 0.153,        SoftPG/grads_norm: 0.277,       SoftPG/kl_div_old: 1.09,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 71,        T: 14,271,      G: -858,        avg_r: -4.29,   avg_G: -1.15e+03,       t: 200, dt: 17.890ms,   SoftClippedDoubleQLearning/grads_max: 0.417,    SoftClippedDoubleQLearning/grads_norm: 1.08,    SoftClippedDoubleQLearning/loss: 0.837, SoftClippedDoubleQLearning/td_error: -0.00151,  SoftClippedDoubleQLearning/td_error_targ: 0.00969,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.89,       SoftPG/grads_max: 0.143,        SoftPG/grads_norm: 0.243,       SoftPG/kl_div_old: 1.08,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 72,        T: 14,472,      G: -1.64e+03,   avg_r: -8.18,   avg_G: -1.2e+03,        t: 200, dt: 17.939ms,   SoftClippedDoubleQLearning/grads_max: 0.418,    SoftClippedDoubleQLearning/grads_norm: 1.09,    SoftClippedDoubleQLearning/loss: 0.828, SoftClippedDoubleQLearning/td_error: -0.00188,  SoftClippedDoubleQLearning/td_error_targ: 0.00845,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.916,      SoftPG/grads_max: 0.126,        SoftPG/grads_norm: 0.234,       SoftPG/kl_div_old: 1.12,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 73,        T: 14,673,      G: -1.17e+03,   avg_r: -5.85,   avg_G: -1.2e+03,        t: 200, dt: 17.870ms,   SoftClippedDoubleQLearning/grads_max: 0.402,    SoftClippedDoubleQLearning/grads_norm: 1.04,    SoftClippedDoubleQLearning/loss: 0.798, SoftClippedDoubleQLearning/td_error: -0.00402,  SoftClippedDoubleQLearning/td_error_targ: -0.00559,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.9,        SoftPG/grads_max: 0.159,        SoftPG/grads_norm: 0.282,       SoftPG/kl_div_old: 1.09,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 74,        T: 14,874,      G: -1.49e+03,   avg_r: -7.45,   avg_G: -1.23e+03,       t: 200, dt: 17.755ms,   SoftClippedDoubleQLearning/grads_max: 0.401,    SoftClippedDoubleQLearning/grads_norm: 1.04,    SoftClippedDoubleQLearning/loss: 0.79,  SoftClippedDoubleQLearning/td_error: -0.00395,  SoftClippedDoubleQLearning/td_error_targ: -0.00844,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.919,      SoftPG/grads_max: 0.168,        SoftPG/grads_norm: 0.298,       SoftPG/kl_div_old: 1.11,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 75,        T: 15,075,      G: -1.8e+03,    avg_r: -9.02,   avg_G: -1.29e+03,       t: 200, dt: 17.804ms,   SoftClippedDoubleQLearning/grads_max: 0.397,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.79,  SoftClippedDoubleQLearning/td_error: 0.00114,   SoftClippedDoubleQLearning/td_error_targ: -0.0144,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.934,      SoftPG/grads_max: 0.177,        SoftPG/grads_norm: 0.306,       SoftPG/kl_div_old: 1.16,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 76,        T: 15,276,      G: -1.37e+03,   avg_r: -6.83,   avg_G: -1.29e+03,       t: 200, dt: 17.985ms,   SoftClippedDoubleQLearning/grads_max: 0.393,    SoftClippedDoubleQLearning/grads_norm: 1.02,    SoftClippedDoubleQLearning/loss: 0.783, SoftClippedDoubleQLearning/td_error: -0.00264,  SoftClippedDoubleQLearning/td_error_targ: -0.0129,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 0.956,      SoftPG/grads_max: 0.169,        SoftPG/grads_norm: 0.31,        SoftPG/kl_div_old: 1.16,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 77,        T: 15,477,      G: -743,        avg_r: -3.72,   avg_G: -1.24e+03,       t: 200, dt: 17.934ms,   SoftClippedDoubleQLearning/grads_max: 0.399,    SoftClippedDoubleQLearning/grads_norm: 1.04,    SoftClippedDoubleQLearning/loss: 0.794, SoftClippedDoubleQLearning/td_error: 0.00778,   SoftClippedDoubleQLearning/td_error_targ: -0.0117,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.01,       SoftPG/grads_max: 0.166,        SoftPG/grads_norm: 0.296,       SoftPG/kl_div_old: 1.22,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.54
[dsac|TrainMonitor|INFO] ep: 78,        T: 15,678,      G: -1.37e+03,   avg_r: -6.83,   avg_G: -1.25e+03,       t: 200, dt: 17.847ms,   SoftClippedDoubleQLearning/grads_max: 0.406,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.794, SoftClippedDoubleQLearning/td_error: -0.000648, SoftClippedDoubleQLearning/td_error_targ: 0.00351,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.03,       SoftPG/grads_max: 0.167,        SoftPG/grads_norm: 0.308,       SoftPG/kl_div_old: 1.24,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.54
[dsac|TrainMonitor|INFO] ep: 79,        T: 15,879,      G: -1.48e+03,   avg_r: -7.39,   avg_G: -1.27e+03,       t: 200, dt: 17.797ms,   SoftClippedDoubleQLearning/grads_max: 0.398,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.788, SoftClippedDoubleQLearning/td_error: -0.00115,  SoftClippedDoubleQLearning/td_error_targ: -0.00406,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.02,       SoftPG/grads_max: 0.141,        SoftPG/grads_norm: 0.252,       SoftPG/kl_div_old: 1.25,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.55
[dsac|TrainMonitor|INFO] ep: 80,        T: 16,080,      G: -1.67e+03,   avg_r: -8.35,   avg_G: -1.31e+03,       t: 200, dt: 17.926ms,   SoftClippedDoubleQLearning/grads_max: 0.399,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.777, SoftClippedDoubleQLearning/td_error: -0.00173,  SoftClippedDoubleQLearning/td_error_targ: -0.00815,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.04,       SoftPG/grads_max: 0.154,        SoftPG/grads_norm: 0.274,       SoftPG/kl_div_old: 1.27,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.55
[dsac|TrainMonitor|INFO] ep: 81,        T: 16,281,      G: -860,        avg_r: -4.3,    avg_G: -1.27e+03,       t: 200, dt: 17.890ms,   SoftClippedDoubleQLearning/grads_max: 0.406,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.789, SoftClippedDoubleQLearning/td_error: 0.00353,   SoftClippedDoubleQLearning/td_error_targ: -0.00365,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.05,       SoftPG/grads_max: 0.169,        SoftPG/grads_norm: 0.305,       SoftPG/kl_div_old: 1.27,        SoftPG/loss: 9.34,      SoftPG/loss_bare: 9.55
[dsac|TrainMonitor|INFO] ep: 82,        T: 16,482,      G: -1.07e+03,   avg_r: -5.35,   avg_G: -1.25e+03,       t: 200, dt: 17.795ms,   SoftClippedDoubleQLearning/grads_max: 0.41,     SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.778, SoftClippedDoubleQLearning/td_error: -0.00925,  SoftClippedDoubleQLearning/td_error_targ: -0.00354,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.09,       SoftPG/grads_max: 0.16, SoftPG/grads_norm: 0.282,       SoftPG/kl_div_old: 1.32,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.56
[dsac|TrainMonitor|INFO] ep: 83,        T: 16,683,      G: -962,        avg_r: -4.81,   avg_G: -1.22e+03,       t: 200, dt: 17.747ms,   SoftClippedDoubleQLearning/grads_max: 0.409,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.786, SoftClippedDoubleQLearning/td_error: -0.00545,  SoftClippedDoubleQLearning/td_error_targ: -0.00268,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.11,       SoftPG/grads_max: 0.142,        SoftPG/grads_norm: 0.259,       SoftPG/kl_div_old: 1.36,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.57
[dsac|TrainMonitor|INFO] ep: 84,        T: 16,884,      G: -1.56e+03,   avg_r: -7.79,   avg_G: -1.25e+03,       t: 200, dt: 17.797ms,   SoftClippedDoubleQLearning/grads_max: 0.407,    SoftClippedDoubleQLearning/grads_norm: 1.05,    SoftClippedDoubleQLearning/loss: 0.78,  SoftClippedDoubleQLearning/td_error: 0.00423,   SoftClippedDoubleQLearning/td_error_targ: -0.00923,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.13,       SoftPG/grads_max: 0.18, SoftPG/grads_norm: 0.324,       SoftPG/kl_div_old: 1.36,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.57
[dsac|TrainMonitor|INFO] ep: 85,        T: 17,085,      G: -1.3e+03,    avg_r: -6.5,    avg_G: -1.26e+03,       t: 200, dt: 17.771ms,   SoftClippedDoubleQLearning/grads_max: 0.4,      SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.765, SoftClippedDoubleQLearning/td_error: -9.46e-05, SoftClippedDoubleQLearning/td_error_targ: -0.008,        EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.18,       SoftPG/grads_max: 0.151,        SoftPG/grads_norm: 0.282,       SoftPG/kl_div_old: 1.45,        SoftPG/loss: 9.35,      SoftPG/loss_bare: 9.58
[dsac|TrainMonitor|INFO] ep: 86,        T: 17,286,      G: -1.49e+03,   avg_r: -7.46,   avg_G: -1.28e+03,       t: 200, dt: 17.933ms,   SoftClippedDoubleQLearning/grads_max: 0.401,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.751, SoftClippedDoubleQLearning/td_error: -0.00512,  SoftClippedDoubleQLearning/td_error_targ: -0.00857,      EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 1.56,       SoftPG/grads_max: 0.132,        SoftPG/grads_norm: 0.259,       SoftPG/kl_div_old: 1.89,        SoftPG/loss: 9.33,      SoftPG/loss_bare: 9.64
[dsac|TrainMonitor|INFO] ep: 87,        T: 17,487,      G: -1.28e+03,   avg_r: -6.41,   avg_G: -1.28e+03,       t: 200, dt: 17.907ms,   SoftClippedDoubleQLearning/grads_max: 0.398,    SoftClippedDoubleQLearning/grads_norm: 1.03,    SoftClippedDoubleQLearning/loss: 0.731, SoftClippedDoubleQLearning/td_error: -0.00714,  SoftClippedDoubleQLearning/td_error_targ: -0.0154,       EntropyRegularizer/beta: 0.04,  EntropyRegularizer/entropy: 4.84,       SoftPG/grads_max: 0.5,  SoftPG/grads_norm: 0.819,       SoftPG/kl_div_old: 5.88,        SoftPG/loss: 8.85,      SoftPG/loss_bare: 9.81
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Traceback (most recent call last):
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 131, in _nan_check_posthook
    xla.check_special(xla.xla_call_p, buffers)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in xla_call

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
    out = compiled_fun(*args)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
    check_special(name, out_bufs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in grads_and_metrics_func

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
    out = compiled_fun(*args)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
    check_special(name, out_bufs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jvp(log_proba)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
    out = compiled_fun(*args)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
    check_special(name, out_bufs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jvp(log_proba)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
    out = compiled_fun(*args)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
    check_special(name, out_bufs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jvp(_einsum)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/schubert/projects/coax/doc/examples/pendulum/dsac.py", line 98, in <module>

  File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 159, in update
    return super().update(transition_batch, None)
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 164, in update
    grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
  File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 193, in grads_and_metrics
    return super().grads_and_metrics(transition_batch, None)
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 233, in grads_and_metrics
    return self._grad_and_metrics_func(
  File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
    _ = clone.call_wrapped(*args)  # probably won't return
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 95, in grads_and_metrics_func
    grads_func(params, state, hyperparams, rng, transition_batch, Adv)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 924, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 995, in value_and_grad_f
    ans, vjp_py, aux = _vjp(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 2317, in _vjp
    out_primal, out_vjp, aux = ad.vjp(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 118, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 68, in loss_func
    self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
  File "/home/schubert/projects/coax/coax/policy_objectives/_soft_pg.py", line 45, in objective_func
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
  File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
    _ = clone.call_wrapped(*args)  # probably won't return
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/schubert/projects/coax/coax/proba_dists/_composite.py", line 137, in log_proba
    return self._structure.log_proba(dist_params, X)
  File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
    _ = clone.call_wrapped(*args)  # probably won't return
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/schubert/projects/coax/coax/proba_dists/_normal.py", line 127, in log_proba
    quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 4939, in einsum
    return _einsum(operands, contractions, precision)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
    out_flat = xla.xla_call(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
    _ = clone.call_wrapped(*args)  # probably won't return
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 5070, in _einsum
    operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 730, in dot_general
    return dot_general_p.bind(lhs, rhs,
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 288, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 465, in standard_jvp
    val_out = primitive.bind(*primals, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 151, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 158, in default_process_primitive
    return primitive.bind(*consts, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 313, in apply_primitive
    return compiled_fun(*args)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 337, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
    check_special(name, out_bufs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
jax._src.traceback_util.UnfilteredStackTrace: FloatingPointError: invalid value (nan) encountered in dot_general

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/schubert/projects/coax/doc/examples/pendulum/dsac.py", line 98, in <module>

  File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 159, in update
    return super().update(transition_batch, None)
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 164, in update
    grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
  File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 193, in grads_and_metrics
    return super().grads_and_metrics(transition_batch, None)
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 233, in grads_and_metrics
    return self._grad_and_metrics_func(
  File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 137, in _nan_check_posthook
    fun._cache_miss(*args, **kwargs)[0]  # probably won't return
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 95, in grads_and_metrics_func
    grads_func(params, state, hyperparams, rng, transition_batch, Adv)
  File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 68, in loss_func
    self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
  File "/home/schubert/projects/coax/coax/policy_objectives/_soft_pg.py", line 45, in objective_func
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
  File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/schubert/projects/coax/coax/proba_dists/_composite.py", line 137, in log_proba
    return self._structure.log_proba(dist_params, X)
  File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/schubert/projects/coax/coax/proba_dists/_normal.py", line 127, in log_proba
    quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 4939, in einsum
    return _einsum(operands, contractions, precision)
  File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 5070, in _einsum
    operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
FloatingPointError: invalid value (nan) encountered in dot_general

KristianHolsheimer commented 2 years ago

Looks like this is happening in Normal.log_proba, specifically in:

quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))

Might be worth putting a breakpoint there.

Perhaps logvar is large and negative? If that's the case, it might mean that the entropy bonus is taking over the loss. In fact, that seems to be what the logs are telling us too.

Perhaps try a smaller entropy regularizer coefficient?
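
As a debugging guard, you could also clamp logvar before it enters the quadratic term. This is just a minimal sketch, not the actual coax implementation: I'm assuming dist_params carries 'mu' and 'logvar' (as the einsum suggests), and the clip bound is purely illustrative.

import jax.numpy as jnp

def log_proba_clamped(dist_params, X, logvar_bound=10.0):
    # hypothetical guard: bound logvar so exp(-logvar) can't overflow to inf
    mu, logvar = dist_params['mu'], dist_params['logvar']
    logvar = jnp.clip(logvar, -logvar_bound, logvar_bound)
    quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
    n = mu.shape[-1]
    return -0.5 * (n * jnp.log(2 * jnp.pi) + jnp.sum(logvar, axis=-1) + quadratic)

If the nan disappears with the clamp in place, that would confirm that logvar is diverging rather than the einsum itself being at fault.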

frederikschubert commented 2 years ago

Perhaps try a smaller entropy regularizer coefficient?

Yes, that definitely helps, but the entropy keeps rising even with a tiny coefficient, and it only happens with the StochasticQ functions. The learning performance of SAC is also not the same as before, so I suspect there are still some differences due to the refactoring.

frederikschubert commented 2 years ago

Ok, I verified the refactored implementation against the old one with this test:

    # needs: from copy import deepcopy; optax's sgd; coax's Q, Policy and
    # ClippedDoubleQLearning; the pre-refactoring updater imported as
    # OldClippedDoubleQLearning
    def test_update_boxspace(self):
        env = self.env_boxspace
        func_q = self.func_q_type1
        func_pi = self.func_pi_boxspace
        transition_batch = self.transition_boxspace

        # refactored implementation
        q1 = Q(func_q, env)
        q2 = Q(func_q, env)
        pi1 = Policy(func_pi, env)
        pi2 = Policy(func_pi, env)
        q_targ1 = q1.copy()
        q_targ2 = q2.copy()
        updater1 = ClippedDoubleQLearning(
            q1, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
        updater2 = ClippedDoubleQLearning(
            q2, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

        # old implementation, starting from identical parameters
        q1_old = deepcopy(q1)
        q2_old = deepcopy(q2)
        q_targ1_old = q1_old.copy()
        q_targ2_old = q2_old.copy()
        old_updater1 = OldClippedDoubleQLearning(
            q1_old, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1_old, q_targ2_old],
            optimizer=sgd(1.0))
        old_updater2 = OldClippedDoubleQLearning(
            q2_old, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1_old, q_targ2_old],
            optimizer=sgd(1.0))

        # snapshot the initial params/state to check that updates actually happen
        params1 = deepcopy(q1.params)
        params2 = deepcopy(q2.params)
        function_state1 = deepcopy(q1.function_state)
        function_state2 = deepcopy(q2.function_state)

        # one update step on the same batch for both implementations
        updater1.update(transition_batch)
        updater2.update(transition_batch)
        old_updater1.update(transition_batch)
        old_updater2.update(transition_batch)

        # both q-functions moved away from their initial params ...
        self.assertPytreeNotEqual(params1, q1.params)
        self.assertPytreeNotEqual(params2, q2.params)
        self.assertPytreeNotEqual(function_state1, q1.function_state)
        self.assertPytreeNotEqual(function_state2, q2.function_state)

        # ... and the old and new implementations agree
        self.assertPytreeAlmostEqual(q1_old.params, q1.params)
        self.assertPytreeAlmostEqual(q2_old.params, q2.params)
        self.assertPytreeAlmostEqual(q1_old.function_state, q1.function_state)
        self.assertPytreeAlmostEqual(q2_old.function_state, q2.function_state)

So the problem is with the StochasticQ function, or rather with finding the right hyperparameters for it. Do we want to defer this until after this change has been merged? I can remove the dsac.py file as long as there is no working implementation.

frederikschubert commented 2 years ago

I think the tree structure of in/out_axes must match those of its input (up to the level you provide). So you'll probably have to transform them using hk.data_structures.to_haiku_dict() and include all keys, i.e. also 'q'.

Ok, so I did try to implement the class using haiku's to_immutable_dict method, but that won't work. vmap treats None values in the in_axes specification by substituting a special sentinel object for them. This substitution operates on the leaves of the FlatMapping, which are computed via jax.tree_flatten, and tree flattening discards leaves whose value is None. Here is a quick demonstration of what works:

import jax
import jax.numpy as jnp
import haiku as hk

def fun(d):
    return d['a'] + d['b']

# works: both entries get a mapped axis, so no None appears in in_axes
vfun = jax.vmap(fun, in_axes=(hk.data_structures.to_immutable_dict({'a': 0, 'b': 0}),))
vfun(hk.data_structures.to_immutable_dict({'a': jnp.array([1, 2]), 'b': jnp.array([0, 3])}))

and what does not work, because of the None value:

# fails: the None leaf is dropped when the FlatMapping is flattened,
# so the in_axes tree no longer matches the argument tree
vfun = jax.vmap(fun, in_axes=(hk.data_structures.to_immutable_dict({'a': 0, 'b': None}),))
vfun(hk.data_structures.to_immutable_dict({'a': jnp.array([1, 2]), 'b': jnp.array([0, 3])}))
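
The root cause is easy to see in isolation, since jax.tree_flatten treats None as an empty subtree:

leaves, treedef = jax.tree_util.tree_flatten({'a': 0, 'b': None})
print(leaves)  # [0] -- the None entry is dropped from the leaves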

So I have to convert the parameters and state with to_mutable_dict before vmapping over them. Does this have any unwanted consequences, @KristianHolsheimer?
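
For reference, this is the shape of the workaround on plain dicts (reusing fun from above; on the real parameters it would be to_mutable_dict applied to the FlatMapping first):

vfun = jax.vmap(fun, in_axes=({'a': 0, 'b': None},))
vfun({'a': jnp.array([1, 2]), 'b': jnp.array(3)})  # 'b' is broadcast, not mapped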

KristianHolsheimer commented 2 years ago

Hmmm, the deeper we get into it, the hackier it starts to feel. Also, we may expect a slight performance drop (typically more memory) from using vmap.

From the perspective of the jitted function, the double for loop is actually fine. The main problem is that ClippedDoubleQLearning doesn't make use of the same class inheritance as the other value-function updaters, and that it looks too convoluted and ugly. But so far the vmap implementation, although an elegant idea, is starting to feel equally convoluted.
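
To illustrate why the loop is fine under jit, here's a toy sketch (illustrative names and shapes only, not the updater code):

import jax
import jax.numpy as jnp

def f(params, x):
    return params['q_targ'] * params['p_targ'] * x

@jax.jit
def min_over_targets(params_list, x):
    # the Python loop unrolls at trace time, so XLA compiles one flat graph;
    # there is no per-iteration dispatch cost at runtime
    return jnp.stack([f(p, x) for p in params_list]).min(axis=0)

params_list = [{'q_targ': 13.0, 'p_targ': 7.0}, {'q_targ': 17.0, 'p_targ': 5.0}]
min_over_targets(params_list, 11.0)

The trade-off is a larger traced graph and a recompile whenever the list length changes, which is harmless here since the number of target networks is fixed.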

frederikschubert commented 2 years ago

Yes, I agree. I think it might be better to just add the stochastic q-function parts to ClippedDoubleQLearning and accept its current form.