Closed frederikschubert closed 2 years ago
Right now this throws an error due to the different sizes of the parameters and function states of the target q functions, e.g. when running examples/pendulum/dsac.py:
ValueError: vmap got inconsistent sizes for array axes to be mapped:
the tree of axis sizes is:
(({'pi_targ': FlatMapping({
'linear': FlatMapping({'b': 8, 'w': 3}),
'linear_1': FlatMapping({'b': 8, 'w': 8}),
'linear_2': FlatMapping({'b': 8, 'w': 8}),
'linear_3': FlatMapping({'b': 2, 'w': 8}),
}), 'q': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})}), 'q_targ': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})})}, {'pi_targ': FlatMapping({
'linear': FlatMapping({'b': 8, 'w': 3}),
'linear_1': FlatMapping({'b': 8, 'w': 8}),
'linear_2': FlatMapping({'b': 8, 'w': 8}),
'linear_3': FlatMapping({'b': 2, 'w': 8}),
}), 'q': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})}), 'q_targ': FlatMapping({'linear': FlatMapping({'b': 51, 'w': 3})})}), ({'pi_targ': FlatMapping({}), 'q': FlatMapping({}), 'q_targ': FlatMapping({})}, {'pi_targ': FlatMapping({}), 'q': FlatMapping({}), 'q_targ': FlatMapping({})}))
I think you'll have to vmap twice: once for q_targ_list and once for pi_targ_list.
I found a workaround using https://github.com/google/jax/issues/3102. However, now I have to select the correct target parameters using the q_targ_idx, which is a Tracer object created by vmap. I had to introduce a greedy_pi_targ parameter to the BaseTDLearningQWithTargetPolicy class because that is what SoftClippedDoubleQLearning is using now.
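Roughly, the index-based selection works like this (a minimal sketch with made-up shapes and a stand-in q-function, not the actual coax code): the target params are stacked along a leading axis and the mapped function picks out one slice with the traced index.

import jax
import jax.numpy as jnp

# q_targ_idx is a Tracer under vmap, so we have to select the slice with
# traceable jnp indexing (leaf[q_targ_idx]) rather than Python list indexing.
def q_targ_value(stacked_params, q_targ_idx, s):
    params = jax.tree_map(lambda leaf: leaf[q_targ_idx], stacked_params)
    return jnp.dot(s, params['w']) + params['b']  # stand-in for the q-function apply

stacked_params = {            # two stacked target networks (toy shapes)
    'w': jnp.ones((2, 3)),    # leading axis = target-network index
    'b': jnp.zeros((2,)),
}
values = jax.vmap(q_targ_value, in_axes=(None, 0, None))(
    stacked_params, jnp.arange(2), jnp.ones(3))  # shape: (2,)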
Here's a toy example of what I had in mind:
import jax
import numpy as np
def f(params, x):
    y = params['q_targ'] * x
    z = params['p_targ']
    return y * z
x = 11
params = {'q_targ': 13, 'p_targ': 7}
f(params, x) # shape: ()
# vectorize over q_targ_list
vf = jax.vmap(f, in_axes=({'q_targ': 0, 'p_targ': None}, None), out_axes=0, axis_name='q_targ')
vparams = {'q_targ': np.array([13, 17, 19]), 'p_targ': 5}
vf(vparams, x) # shape: (len(q_targ_list),)
# vectorize over pi_targ_list
vvf = jax.vmap(vf, in_axes=({'q_targ': None, 'p_targ': 0}, None), out_axes=0, axis_name='p_targ')
vvparams = {'q_targ': np.array([13, 17, 19]), 'p_targ': np.array([5, 7])}
vvf(vvparams, x) # shape: (len(pi_targ_list), len(q_targ_list))
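Note that the two vmaps compose: the outer vmap over p_targ adds the leading axis, so the final output has shape (len(pi_targ_list), len(q_targ_list)). The axis_name arguments aren't strictly required here; they just name each mapped axis so it can be referred to later (e.g. in jax.lax collectives) if needed.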
As for the state and params, we'll need to stack them separately, e.g.
@property
def target_params(self):
    return hk.data_structures.to_immutable_dict({
        'q': self.q.params,
        'q_targ': jax.tree_multimap(
            lambda *xs: jnp.stack(xs, axis=0),
            *(q.params for q in self.q_targ_list)),
        'pi_targ': jax.tree_multimap(
            lambda *xs: jnp.stack(xs, axis=0),
            *(pi.params for pi in self.pi_targ_list)),
    })

@property
def target_state(self):
    ...  # same idea
The thing that isn't great about this (performance-wise) is the jnp.stack outside of a jax.jit call.
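To make the stacking concrete, here's a standalone toy version of what happens to each leaf (made-up shapes; jax.tree_multimap was later renamed to jax.tree_map in newer JAX releases):

import jax
import jax.numpy as jnp

# Two toy target-network pytrees with identical structure.
q_targ_params_list = [
    {'linear': {'w': jnp.ones((3, 8)), 'b': jnp.zeros(8)}},
    {'linear': {'w': 2 * jnp.ones((3, 8)), 'b': jnp.ones(8)}},
]

# Stack leaf-wise: every leaf gains a leading axis of size
# len(q_targ_params_list), which is the axis that vmap then maps over.
stacked = jax.tree_multimap(
    lambda *xs: jnp.stack(xs, axis=0), *q_targ_params_list)

assert stacked['linear']['w'].shape == (2, 3, 8)
assert stacked['linear']['b'].shape == (2, 8)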
Thank you for taking the time to show me this example!
> Here's a toy example of what I had in mind:
Yes, that looks more reasonable than my hacky approach. I got it to work now, but I still have to refactor it, obviously. The general structure is now visible, the tests should pass, and the dsac example is learning.
It even seems that the code is now a bit faster than before. For SAC (a simpler network) I previously got around 11 ms per step, and now it's at around 8 ms. I wonder why this is the case...
(coax) ~/projects/coax on ilmarinen.tnt.uni-hannover.de
❯ xvfb-run python /home/schubert/projects/coax/doc/examples/pendulum/dsac.py
[dsac|absl|INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
[dsac|absl|INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
[dsac|TrainMonitor|INFO] ep: 1, T: 201, G: -1.53e+03, avg_r: -7.65, avg_G: -1.53e+03, t: 200, dt: 7.920ms
[dsac|TrainMonitor|INFO] ep: 2, T: 402, G: -979, avg_r: -4.9, avg_G: -1.25e+03, t: 200, dt: 4.962ms
[dsac|TrainMonitor|INFO] ep: 3, T: 603, G: -758, avg_r: -3.79, avg_G: -1.09e+03, t: 200, dt: 5.308ms
[dsac|TrainMonitor|INFO] ep: 4, T: 804, G: -1.66e+03, avg_r: -8.31, avg_G: -1.23e+03, t: 200, dt: 4.933ms
[dsac|TrainMonitor|INFO] ep: 5, T: 1,005, G: -1.41e+03, avg_r: -7.07, avg_G: -1.27e+03, t: 200, dt: 4.955ms
[dsac|TrainMonitor|INFO] ep: 6, T: 1,206, G: -888, avg_r: -4.44, avg_G: -1.2e+03, t: 200, dt: 4.949ms
[dsac|TrainMonitor|INFO] ep: 7, T: 1,407, G: -699, avg_r: -3.49, avg_G: -1.13e+03, t: 200, dt: 4.903ms
[dsac|TrainMonitor|INFO] ep: 8, T: 1,608, G: -895, avg_r: -4.47, avg_G: -1.1e+03, t: 200, dt: 4.869ms
[dsac|TrainMonitor|INFO] ep: 9, T: 1,809, G: -1.69e+03, avg_r: -8.43, avg_G: -1.17e+03, t: 200, dt: 4.966ms
[dsac|TrainMonitor|INFO] ep: 10, T: 2,010, G: -1.8e+03, avg_r: -8.98, avg_G: -1.23e+03, t: 200, dt: 4.993ms
[dsac|TrainMonitor|INFO] ep: 11, T: 2,211, G: -1.47e+03, avg_r: -7.34, avg_G: -1.25e+03, t: 200, dt: 5.066ms
[dsac|TrainMonitor|INFO] ep: 12, T: 2,412, G: -1.48e+03, avg_r: -7.4, avg_G: -1.28e+03, t: 200, dt: 4.989ms
[dsac|TrainMonitor|INFO] ep: 13, T: 2,613, G: -1.19e+03, avg_r: -5.94, avg_G: -1.27e+03, t: 200, dt: 4.994ms
[dsac|TrainMonitor|INFO] ep: 14, T: 2,814, G: -1.36e+03, avg_r: -6.8, avg_G: -1.28e+03, t: 200, dt: 5.084ms
[dsac|TrainMonitor|INFO] ep: 15, T: 3,015, G: -1.38e+03, avg_r: -6.9, avg_G: -1.29e+03, t: 200, dt: 4.978ms
[dsac|TrainMonitor|INFO] ep: 16, T: 3,216, G: -971, avg_r: -4.86, avg_G: -1.26e+03, t: 200, dt: 4.898ms
[dsac|TrainMonitor|INFO] ep: 17, T: 3,417, G: -1.06e+03, avg_r: -5.31, avg_G: -1.24e+03, t: 200, dt: 4.955ms
[dsac|TrainMonitor|INFO] ep: 18, T: 3,618, G: -759, avg_r: -3.79, avg_G: -1.19e+03, t: 200, dt: 5.112ms
[dsac|TrainMonitor|INFO] ep: 19, T: 3,819, G: -999, avg_r: -4.99, avg_G: -1.17e+03, t: 200, dt: 5.148ms
[dsac|TrainMonitor|INFO] ep: 20, T: 4,020, G: -1.66e+03, avg_r: -8.3, avg_G: -1.22e+03, t: 200, dt: 4.893ms
[dsac|TrainMonitor|INFO] ep: 21, T: 4,221, G: -1.08e+03, avg_r: -5.41, avg_G: -1.21e+03, t: 200, dt: 4.952ms
[dsac|TrainMonitor|INFO] ep: 22, T: 4,422, G: -1.46e+03, avg_r: -7.3, avg_G: -1.23e+03, t: 200, dt: 4.868ms
[dsac|TrainMonitor|INFO] ep: 23, T: 4,623, G: -990, avg_r: -4.95, avg_G: -1.21e+03, t: 200, dt: 5.005ms
[dsac|TrainMonitor|INFO] ep: 24, T: 4,824, G: -1.78e+03, avg_r: -8.92, avg_G: -1.26e+03, t: 200, dt: 5.118ms
[dsac|TrainMonitor|INFO] ep: 25, T: 5,025, G: -916, avg_r: -4.58, avg_G: -1.23e+03, t: 200, dt: 42.785ms, SoftClippedDoubleQLearning/grads_max: 2.43, SoftClippedDoubleQLearning/grads_norm: 3.21, SoftClippedDoubleQLearning/loss: 3.93, SoftClippedDoubleQLearning/td_error: -4.25, SoftClippedDoubleQLearning/td_error_targ: 0
[dsac|TrainMonitor|INFO] ep: 26, T: 5,226, G: -1.21e+03, avg_r: -6.03, avg_G: -1.23e+03, t: 200, dt: 47.650ms, SoftClippedDoubleQLearning/grads_max: 1.8, SoftClippedDoubleQLearning/grads_norm: 2.51, SoftClippedDoubleQLearning/loss: 3.42, SoftClippedDoubleQLearning/td_error: -4.03, SoftClippedDoubleQLearning/td_error_targ: -0.171
[dsac|TrainMonitor|INFO] ep: 27, T: 5,427, G: -889, avg_r: -4.44, avg_G: -1.19e+03, t: 200, dt: 15.969ms, SoftClippedDoubleQLearning/grads_max: 1.37, SoftClippedDoubleQLearning/grads_norm: 2.15, SoftClippedDoubleQLearning/loss: 2.49, SoftClippedDoubleQLearning/td_error: -3.31, SoftClippedDoubleQLearning/td_error_targ: -0.899
[dsac|TrainMonitor|INFO] ep: 28, T: 5,628, G: -1.26e+03, avg_r: -6.29, avg_G: -1.2e+03, t: 200, dt: 16.326ms, SoftClippedDoubleQLearning/grads_max: 1.04, SoftClippedDoubleQLearning/grads_norm: 1.86, SoftClippedDoubleQLearning/loss: 1.8, SoftClippedDoubleQLearning/td_error: -2.46, SoftClippedDoubleQLearning/td_error_targ: -1.71
[dsac|TrainMonitor|INFO] ep: 29, T: 5,829, G: -1.07e+03, avg_r: -5.34, avg_G: -1.19e+03, t: 200, dt: 16.240ms, SoftClippedDoubleQLearning/grads_max: 0.79, SoftClippedDoubleQLearning/grads_norm: 1.58, SoftClippedDoubleQLearning/loss: 1.34, SoftClippedDoubleQLearning/td_error: -1.62, SoftClippedDoubleQLearning/td_error_targ: -2.34
[dsac|TrainMonitor|INFO] ep: 30, T: 6,030, G: -936, avg_r: -4.68, avg_G: -1.16e+03, t: 200, dt: 16.112ms, SoftClippedDoubleQLearning/grads_max: 0.628, SoftClippedDoubleQLearning/grads_norm: 1.36, SoftClippedDoubleQLearning/loss: 1.08, SoftClippedDoubleQLearning/td_error: -0.987, SoftClippedDoubleQLearning/td_error_targ: -2.65
[dsac|TrainMonitor|INFO] ep: 31, T: 6,231, G: -1.71e+03, avg_r: -8.56, avg_G: -1.22e+03, t: 200, dt: 15.815ms, SoftClippedDoubleQLearning/grads_max: 0.535, SoftClippedDoubleQLearning/grads_norm: 1.22, SoftClippedDoubleQLearning/loss: 0.941, SoftClippedDoubleQLearning/td_error: -0.611, SoftClippedDoubleQLearning/td_error_targ: -2.67
[dsac|TrainMonitor|INFO] ep: 32, T: 6,432, G: -1.64e+03, avg_r: -8.18, avg_G: -1.26e+03, t: 200, dt: 16.300ms, SoftClippedDoubleQLearning/grads_max: 0.478, SoftClippedDoubleQLearning/grads_norm: 1.11, SoftClippedDoubleQLearning/loss: 0.863, SoftClippedDoubleQLearning/td_error: -0.367, SoftClippedDoubleQLearning/td_error_targ: -2.49
[dsac|TrainMonitor|INFO] ep: 33, T: 6,633, G: -1.23e+03, avg_r: -6.16, avg_G: -1.26e+03, t: 200, dt: 16.151ms, SoftClippedDoubleQLearning/grads_max: 0.444, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.814, SoftClippedDoubleQLearning/td_error: -0.237, SoftClippedDoubleQLearning/td_error_targ: -2.2
[dsac|TrainMonitor|INFO] ep: 34, T: 6,834, G: -879, avg_r: -4.4, avg_G: -1.22e+03, t: 200, dt: 16.293ms, SoftClippedDoubleQLearning/grads_max: 0.423, SoftClippedDoubleQLearning/grads_norm: 1.01, SoftClippedDoubleQLearning/loss: 0.793, SoftClippedDoubleQLearning/td_error: -0.164, SoftClippedDoubleQLearning/td_error_targ: -1.88
[dsac|TrainMonitor|INFO] ep: 35, T: 7,035, G: -768, avg_r: -3.84, avg_G: -1.17e+03, t: 200, dt: 16.244ms, SoftClippedDoubleQLearning/grads_max: 0.438, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.826, SoftClippedDoubleQLearning/td_error: -0.0973, SoftClippedDoubleQLearning/td_error_targ: -1.51
[dsac|TrainMonitor|INFO] ep: 36, T: 7,236, G: -1.43e+03, avg_r: -7.17, avg_G: -1.2e+03, t: 200, dt: 15.914ms, SoftClippedDoubleQLearning/grads_max: 0.457, SoftClippedDoubleQLearning/grads_norm: 1.09, SoftClippedDoubleQLearning/loss: 0.84, SoftClippedDoubleQLearning/td_error: -0.0822, SoftClippedDoubleQLearning/td_error_targ: -1.21
[dsac|TrainMonitor|INFO] ep: 37, T: 7,437, G: -1.15e+03, avg_r: -5.75, avg_G: -1.19e+03, t: 200, dt: 15.901ms, SoftClippedDoubleQLearning/grads_max: 0.435, SoftClippedDoubleQLearning/grads_norm: 1.04, SoftClippedDoubleQLearning/loss: 0.796, SoftClippedDoubleQLearning/td_error: -0.0736, SoftClippedDoubleQLearning/td_error_targ: -0.953
[dsac|TrainMonitor|INFO] ep: 38, T: 7,638, G: -879, avg_r: -4.39, avg_G: -1.16e+03, t: 200, dt: 30.564ms, SoftClippedDoubleQLearning/grads_max: 0.428, SoftClippedDoubleQLearning/grads_norm: 1.04, SoftClippedDoubleQLearning/loss: 0.788, SoftClippedDoubleQLearning/td_error: -0.0607, SoftClippedDoubleQLearning/td_error_targ: -0.759, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.15, SoftPG/grads_max: 0.27, SoftPG/grads_norm: 0.432, SoftPG/kl_div_old: 1.4, SoftPG/loss: 8.54, SoftPG/loss_bare: 8.77
[dsac|TrainMonitor|INFO] ep: 39, T: 7,839, G: -1.76e+03, avg_r: -8.8, avg_G: -1.22e+03, t: 200, dt: 17.507ms, SoftClippedDoubleQLearning/grads_max: 0.432, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.798, SoftClippedDoubleQLearning/td_error: -0.0452, SoftClippedDoubleQLearning/td_error_targ: -0.593, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.02, SoftPG/grads_max: 0.407, SoftPG/grads_norm: 0.783, SoftPG/kl_div_old: 1.25, SoftPG/loss: 8.65, SoftPG/loss_bare: 8.86
[dsac|TrainMonitor|INFO] ep: 40, T: 8,040, G: -1.65e+03, avg_r: -8.25, avg_G: -1.27e+03, t: 200, dt: 17.767ms, SoftClippedDoubleQLearning/grads_max: 0.43, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.789, SoftClippedDoubleQLearning/td_error: -0.0393, SoftClippedDoubleQLearning/td_error_targ: -0.472, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.548, SoftPG/grads_max: 0.629, SoftPG/grads_norm: 1.19, SoftPG/kl_div_old: 0.677, SoftPG/loss: 8.62, SoftPG/loss_bare: 8.73
[dsac|TrainMonitor|INFO] ep: 41, T: 8,241, G: -1.64e+03, avg_r: -8.22, avg_G: -1.3e+03, t: 200, dt: 17.708ms, SoftClippedDoubleQLearning/grads_max: 0.413, SoftClippedDoubleQLearning/grads_norm: 1.02, SoftClippedDoubleQLearning/loss: 0.77, SoftClippedDoubleQLearning/td_error: -0.0282, SoftClippedDoubleQLearning/td_error_targ: -0.388, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.154, SoftPG/grads_max: 0.651, SoftPG/grads_norm: 1.14, SoftPG/kl_div_old: 0.186, SoftPG/loss: 8.7, SoftPG/loss_bare: 8.73
[dsac|TrainMonitor|INFO] ep: 42, T: 8,442, G: -799, avg_r: -4, avg_G: -1.25e+03, t: 200, dt: 18.241ms, SoftClippedDoubleQLearning/grads_max: 0.401, SoftClippedDoubleQLearning/grads_norm: 0.999, SoftClippedDoubleQLearning/loss: 0.769, SoftClippedDoubleQLearning/td_error: -0.0266, SoftClippedDoubleQLearning/td_error_targ: -0.311, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.129, SoftPG/grads_max: 0.449, SoftPG/grads_norm: 0.772, SoftPG/kl_div_old: 0.147, SoftPG/loss: 8.83, SoftPG/loss_bare: 8.85
[dsac|TrainMonitor|INFO] ep: 43, T: 8,643, G: -1.17e+03, avg_r: -5.84, avg_G: -1.24e+03, t: 200, dt: 17.834ms, SoftClippedDoubleQLearning/grads_max: 0.402, SoftClippedDoubleQLearning/grads_norm: 1.01, SoftClippedDoubleQLearning/loss: 0.77, SoftClippedDoubleQLearning/td_error: -0.0296, SoftClippedDoubleQLearning/td_error_targ: -0.251, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.195, SoftPG/grads_max: 0.487, SoftPG/grads_norm: 0.84, SoftPG/kl_div_old: 0.24, SoftPG/loss: 8.95, SoftPG/loss_bare: 8.99
[dsac|TrainMonitor|INFO] ep: 44, T: 8,844, G: -1.37e+03, avg_r: -6.87, avg_G: -1.26e+03, t: 200, dt: 18.037ms, SoftClippedDoubleQLearning/grads_max: 0.393, SoftClippedDoubleQLearning/grads_norm: 0.994, SoftClippedDoubleQLearning/loss: 0.759, SoftClippedDoubleQLearning/td_error: -0.0255, SoftClippedDoubleQLearning/td_error_targ: -0.208, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.243, SoftPG/grads_max: 0.468, SoftPG/grads_norm: 0.84, SoftPG/kl_div_old: 0.309, SoftPG/loss: 9.04, SoftPG/loss_bare: 9.09
[dsac|TrainMonitor|INFO] ep: 45, T: 9,045, G: -857, avg_r: -4.29, avg_G: -1.22e+03, t: 200, dt: 18.677ms, SoftClippedDoubleQLearning/grads_max: 0.395, SoftClippedDoubleQLearning/grads_norm: 1, SoftClippedDoubleQLearning/loss: 0.761, SoftClippedDoubleQLearning/td_error: -0.0199, SoftClippedDoubleQLearning/td_error_targ: -0.173, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.306, SoftPG/grads_max: 0.362, SoftPG/grads_norm: 0.63, SoftPG/kl_div_old: 0.386, SoftPG/loss: 9.11, SoftPG/loss_bare: 9.17
[dsac|TrainMonitor|INFO] ep: 46, T: 9,246, G: -1.66e+03, avg_r: -8.29, avg_G: -1.26e+03, t: 200, dt: 18.522ms, SoftClippedDoubleQLearning/grads_max: 0.401, SoftClippedDoubleQLearning/grads_norm: 1.02, SoftClippedDoubleQLearning/loss: 0.778, SoftClippedDoubleQLearning/td_error: -0.0136, SoftClippedDoubleQLearning/td_error_targ: -0.14, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.342, SoftPG/grads_max: 0.293, SoftPG/grads_norm: 0.491, SoftPG/kl_div_old: 0.439, SoftPG/loss: 9.17, SoftPG/loss_bare: 9.24
[dsac|TrainMonitor|INFO] ep: 47, T: 9,447, G: -1.69e+03, avg_r: -8.45, avg_G: -1.3e+03, t: 200, dt: 18.874ms, SoftClippedDoubleQLearning/grads_max: 0.392, SoftClippedDoubleQLearning/grads_norm: 1, SoftClippedDoubleQLearning/loss: 0.757, SoftClippedDoubleQLearning/td_error: -0.0188, SoftClippedDoubleQLearning/td_error_targ: -0.116, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.421, SoftPG/grads_max: 0.297, SoftPG/grads_norm: 0.51, SoftPG/kl_div_old: 0.525, SoftPG/loss: 9.21, SoftPG/loss_bare: 9.29
[dsac|TrainMonitor|INFO] ep: 48, T: 9,648, G: -1.49e+03, avg_r: -7.46, avg_G: -1.32e+03, t: 200, dt: 18.156ms, SoftClippedDoubleQLearning/grads_max: 0.392, SoftClippedDoubleQLearning/grads_norm: 1.01, SoftClippedDoubleQLearning/loss: 0.757, SoftClippedDoubleQLearning/td_error: -0.0144, SoftClippedDoubleQLearning/td_error_targ: -0.0985, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.48, SoftPG/grads_max: 0.278, SoftPG/grads_norm: 0.5, SoftPG/kl_div_old: 0.587, SoftPG/loss: 9.24, SoftPG/loss_bare: 9.33
[dsac|TrainMonitor|INFO] ep: 49, T: 9,849, G: -1.81e+03, avg_r: -9.06, avg_G: -1.37e+03, t: 200, dt: 17.650ms, SoftClippedDoubleQLearning/grads_max: 0.373, SoftClippedDoubleQLearning/grads_norm: 0.965, SoftClippedDoubleQLearning/loss: 0.719, SoftClippedDoubleQLearning/td_error: -0.0183, SoftClippedDoubleQLearning/td_error_targ: -0.092, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.555, SoftPG/grads_max: 0.255, SoftPG/grads_norm: 0.436, SoftPG/kl_div_old: 0.687, SoftPG/loss: 9.27, SoftPG/loss_bare: 9.38
[dsac|generate_gif|INFO] recorded episode to: ./data/gifs/dsac/T00010000.gif
[dsac|TrainMonitor|INFO] ep: 50, T: 10,050, G: -625, avg_r: -3.12, avg_G: -1.3e+03, t: 200, dt: 45.271ms, SoftClippedDoubleQLearning/grads_max: 0.378, SoftClippedDoubleQLearning/grads_norm: 0.985, SoftClippedDoubleQLearning/loss: 0.749, SoftClippedDoubleQLearning/td_error: 0.000823, SoftClippedDoubleQLearning/td_error_targ: -0.0778, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.593, SoftPG/grads_max: 0.23, SoftPG/grads_norm: 0.397, SoftPG/kl_div_old: 0.717, SoftPG/loss: 9.29, SoftPG/loss_bare: 9.41
[dsac|TrainMonitor|INFO] ep: 51, T: 10,251, G: -1.39e+03, avg_r: -6.96, avg_G: -1.31e+03, t: 200, dt: 17.901ms, SoftClippedDoubleQLearning/grads_max: 0.386, SoftClippedDoubleQLearning/grads_norm: 1.01, SoftClippedDoubleQLearning/loss: 0.763, SoftClippedDoubleQLearning/td_error: -0.00187, SoftClippedDoubleQLearning/td_error_targ: -0.0488, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.646, SoftPG/grads_max: 0.201, SoftPG/grads_norm: 0.352, SoftPG/kl_div_old: 0.796, SoftPG/loss: 9.31, SoftPG/loss_bare: 9.44
[dsac|TrainMonitor|INFO] ep: 52, T: 10,452, G: -1.43e+03, avg_r: -7.16, avg_G: -1.32e+03, t: 200, dt: 17.584ms, SoftClippedDoubleQLearning/grads_max: 0.374, SoftClippedDoubleQLearning/grads_norm: 0.981, SoftClippedDoubleQLearning/loss: 0.738, SoftClippedDoubleQLearning/td_error: -0.00105, SoftClippedDoubleQLearning/td_error_targ: -0.0499, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.701, SoftPG/grads_max: 0.18, SoftPG/grads_norm: 0.3, SoftPG/kl_div_old: 0.857, SoftPG/loss: 9.32, SoftPG/loss_bare: 9.46
[dsac|TrainMonitor|INFO] ep: 53, T: 10,653, G: -860, avg_r: -4.3, avg_G: -1.27e+03, t: 200, dt: 17.879ms, SoftClippedDoubleQLearning/grads_max: 0.376, SoftClippedDoubleQLearning/grads_norm: 0.99, SoftClippedDoubleQLearning/loss: 0.752, SoftClippedDoubleQLearning/td_error: -0.000538, SoftClippedDoubleQLearning/td_error_targ: -0.0384, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.742, SoftPG/grads_max: 0.144, SoftPG/grads_norm: 0.254, SoftPG/kl_div_old: 0.913, SoftPG/loss: 9.33, SoftPG/loss_bare: 9.48
[dsac|TrainMonitor|INFO] ep: 54, T: 10,854, G: -759, avg_r: -3.79, avg_G: -1.22e+03, t: 200, dt: 17.867ms, SoftClippedDoubleQLearning/grads_max: 0.384, SoftClippedDoubleQLearning/grads_norm: 1.01, SoftClippedDoubleQLearning/loss: 0.764, SoftClippedDoubleQLearning/td_error: -0.00205, SoftClippedDoubleQLearning/td_error_targ: -0.0196, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.771, SoftPG/grads_max: 0.159, SoftPG/grads_norm: 0.285, SoftPG/kl_div_old: 0.931, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.49
[dsac|TrainMonitor|INFO] ep: 55, T: 11,055, G: -1.19e+03, avg_r: -5.94, avg_G: -1.22e+03, t: 200, dt: 17.724ms, SoftClippedDoubleQLearning/grads_max: 0.383, SoftClippedDoubleQLearning/grads_norm: 1.01, SoftClippedDoubleQLearning/loss: 0.77, SoftClippedDoubleQLearning/td_error: 0.00872, SoftClippedDoubleQLearning/td_error_targ: -0.0199, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.796, SoftPG/grads_max: 0.122, SoftPG/grads_norm: 0.194, SoftPG/kl_div_old: 0.958, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.5
[dsac|TrainMonitor|INFO] ep: 56, T: 11,256, G: -1.61e+03, avg_r: -8.04, avg_G: -1.26e+03, t: 200, dt: 17.715ms, SoftClippedDoubleQLearning/grads_max: 0.391, SoftClippedDoubleQLearning/grads_norm: 1.02, SoftClippedDoubleQLearning/loss: 0.774, SoftClippedDoubleQLearning/td_error: -0.000953, SoftClippedDoubleQLearning/td_error_targ: -0.0106, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.812, SoftPG/grads_max: 0.138, SoftPG/grads_norm: 0.243, SoftPG/kl_div_old: 0.992, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.51
[dsac|TrainMonitor|INFO] ep: 57, T: 11,457, G: -747, avg_r: -3.74, avg_G: -1.21e+03, t: 200, dt: 17.830ms, SoftClippedDoubleQLearning/grads_max: 0.38, SoftClippedDoubleQLearning/grads_norm: 1, SoftClippedDoubleQLearning/loss: 0.755, SoftClippedDoubleQLearning/td_error: -0.00426, SoftClippedDoubleQLearning/td_error_targ: -0.02, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.832, SoftPG/grads_max: 0.148, SoftPG/grads_norm: 0.258, SoftPG/kl_div_old: 1.01, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 58, T: 11,658, G: -765, avg_r: -3.83, avg_G: -1.16e+03, t: 200, dt: 17.787ms, SoftClippedDoubleQLearning/grads_max: 0.401, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.808, SoftClippedDoubleQLearning/td_error: 0.0121, SoftClippedDoubleQLearning/td_error_targ: 0.00259, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.862, SoftPG/grads_max: 0.126, SoftPG/grads_norm: 0.219, SoftPG/kl_div_old: 1.05, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 59, T: 11,859, G: -1.27e+03, avg_r: -6.35, avg_G: -1.17e+03, t: 200, dt: 17.920ms, SoftClippedDoubleQLearning/grads_max: 0.402, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.789, SoftClippedDoubleQLearning/td_error: -0.018, SoftClippedDoubleQLearning/td_error_targ: 0.00638, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.874, SoftPG/grads_max: 0.114, SoftPG/grads_norm: 0.202, SoftPG/kl_div_old: 1.06, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 60, T: 12,060, G: -1.01e+03, avg_r: -5.05, avg_G: -1.16e+03, t: 200, dt: 17.924ms, SoftClippedDoubleQLearning/grads_max: 0.4, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.798, SoftClippedDoubleQLearning/td_error: 0.00773, SoftClippedDoubleQLearning/td_error_targ: 9.54e-05, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.87, SoftPG/grads_max: 0.125, SoftPG/grads_norm: 0.216, SoftPG/kl_div_old: 1.07, SoftPG/loss: 9.36, SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 61, T: 12,261, G: -1.55e+03, avg_r: -7.74, avg_G: -1.2e+03, t: 200, dt: 18.050ms, SoftClippedDoubleQLearning/grads_max: 0.395, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.788, SoftClippedDoubleQLearning/td_error: 0.00251, SoftClippedDoubleQLearning/td_error_targ: -2.04e-05, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.871, SoftPG/grads_max: 0.133, SoftPG/grads_norm: 0.223, SoftPG/kl_div_old: 1.07, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 62, T: 12,462, G: -856, avg_r: -4.28, avg_G: -1.16e+03, t: 200, dt: 17.947ms, SoftClippedDoubleQLearning/grads_max: 0.393, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.794, SoftClippedDoubleQLearning/td_error: 0.000313, SoftClippedDoubleQLearning/td_error_targ: -0.00501, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.864, SoftPG/grads_max: 0.118, SoftPG/grads_norm: 0.204, SoftPG/kl_div_old: 1.06, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 63, T: 12,663, G: -969, avg_r: -4.84, avg_G: -1.14e+03, t: 200, dt: 17.778ms, SoftClippedDoubleQLearning/grads_max: 0.401, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.806, SoftClippedDoubleQLearning/td_error: -0.0033, SoftClippedDoubleQLearning/td_error_targ: 0.00958, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.872, SoftPG/grads_max: 0.122, SoftPG/grads_norm: 0.216, SoftPG/kl_div_old: 1.08, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 64, T: 12,864, G: -1.24e+03, avg_r: -6.2, avg_G: -1.15e+03, t: 200, dt: 17.919ms, SoftClippedDoubleQLearning/grads_max: 0.389, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.795, SoftClippedDoubleQLearning/td_error: 0.00509, SoftClippedDoubleQLearning/td_error_targ: -0.00353, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.861, SoftPG/grads_max: 0.142, SoftPG/grads_norm: 0.238, SoftPG/kl_div_old: 1.05, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 65, T: 13,065, G: -879, avg_r: -4.39, avg_G: -1.12e+03, t: 200, dt: 17.753ms, SoftClippedDoubleQLearning/grads_max: 0.399, SoftClippedDoubleQLearning/grads_norm: 1.04, SoftClippedDoubleQLearning/loss: 0.807, SoftClippedDoubleQLearning/td_error: -0.000263, SoftClippedDoubleQLearning/td_error_targ: 0.0021, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.88, SoftPG/grads_max: 0.12, SoftPG/grads_norm: 0.206, SoftPG/kl_div_old: 1.08, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 66, T: 13,266, G: -1.87e+03, avg_r: -9.35, avg_G: -1.2e+03, t: 200, dt: 18.146ms, SoftClippedDoubleQLearning/grads_max: 0.412, SoftClippedDoubleQLearning/grads_norm: 1.07, SoftClippedDoubleQLearning/loss: 0.811, SoftClippedDoubleQLearning/td_error: -0.0122, SoftClippedDoubleQLearning/td_error_targ: 0.00878, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.874, SoftPG/grads_max: 0.122, SoftPG/grads_norm: 0.223, SoftPG/kl_div_old: 1.07, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 67, T: 13,467, G: -968, avg_r: -4.84, avg_G: -1.18e+03, t: 200, dt: 17.884ms, SoftClippedDoubleQLearning/grads_max: 0.404, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.808, SoftClippedDoubleQLearning/td_error: 0.00291, SoftClippedDoubleQLearning/td_error_targ: -0.00665, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.872, SoftPG/grads_max: 0.129, SoftPG/grads_norm: 0.225, SoftPG/kl_div_old: 1.06, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 68, T: 13,668, G: -626, avg_r: -3.13, avg_G: -1.12e+03, t: 200, dt: 17.906ms, SoftClippedDoubleQLearning/grads_max: 0.405, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.807, SoftClippedDoubleQLearning/td_error: -0.000917, SoftClippedDoubleQLearning/td_error_targ: 0.00247, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.881, SoftPG/grads_max: 0.149, SoftPG/grads_norm: 0.272, SoftPG/kl_div_old: 1.08, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 69, T: 13,869, G: -1.62e+03, avg_r: -8.1, avg_G: -1.17e+03, t: 200, dt: 17.661ms, SoftClippedDoubleQLearning/grads_max: 0.423, SoftClippedDoubleQLearning/grads_norm: 1.1, SoftClippedDoubleQLearning/loss: 0.841, SoftClippedDoubleQLearning/td_error: 0.00232, SoftClippedDoubleQLearning/td_error_targ: 0.0128, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.885, SoftPG/grads_max: 0.155, SoftPG/grads_norm: 0.282, SoftPG/kl_div_old: 1.07, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 70, T: 14,070, G: -1.33e+03, avg_r: -6.63, avg_G: -1.19e+03, t: 200, dt: 17.895ms, SoftClippedDoubleQLearning/grads_max: 0.412, SoftClippedDoubleQLearning/grads_norm: 1.07, SoftClippedDoubleQLearning/loss: 0.816, SoftClippedDoubleQLearning/td_error: -0.00507, SoftClippedDoubleQLearning/td_error_targ: 0.00611, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.885, SoftPG/grads_max: 0.153, SoftPG/grads_norm: 0.277, SoftPG/kl_div_old: 1.09, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 71, T: 14,271, G: -858, avg_r: -4.29, avg_G: -1.15e+03, t: 200, dt: 17.890ms, SoftClippedDoubleQLearning/grads_max: 0.417, SoftClippedDoubleQLearning/grads_norm: 1.08, SoftClippedDoubleQLearning/loss: 0.837, SoftClippedDoubleQLearning/td_error: -0.00151, SoftClippedDoubleQLearning/td_error_targ: 0.00969, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.89, SoftPG/grads_max: 0.143, SoftPG/grads_norm: 0.243, SoftPG/kl_div_old: 1.08, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 72, T: 14,472, G: -1.64e+03, avg_r: -8.18, avg_G: -1.2e+03, t: 200, dt: 17.939ms, SoftClippedDoubleQLearning/grads_max: 0.418, SoftClippedDoubleQLearning/grads_norm: 1.09, SoftClippedDoubleQLearning/loss: 0.828, SoftClippedDoubleQLearning/td_error: -0.00188, SoftClippedDoubleQLearning/td_error_targ: 0.00845, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.916, SoftPG/grads_max: 0.126, SoftPG/grads_norm: 0.234, SoftPG/kl_div_old: 1.12, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 73, T: 14,673, G: -1.17e+03, avg_r: -5.85, avg_G: -1.2e+03, t: 200, dt: 17.870ms, SoftClippedDoubleQLearning/grads_max: 0.402, SoftClippedDoubleQLearning/grads_norm: 1.04, SoftClippedDoubleQLearning/loss: 0.798, SoftClippedDoubleQLearning/td_error: -0.00402, SoftClippedDoubleQLearning/td_error_targ: -0.00559, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.9, SoftPG/grads_max: 0.159, SoftPG/grads_norm: 0.282, SoftPG/kl_div_old: 1.09, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 74, T: 14,874, G: -1.49e+03, avg_r: -7.45, avg_G: -1.23e+03, t: 200, dt: 17.755ms, SoftClippedDoubleQLearning/grads_max: 0.401, SoftClippedDoubleQLearning/grads_norm: 1.04, SoftClippedDoubleQLearning/loss: 0.79, SoftClippedDoubleQLearning/td_error: -0.00395, SoftClippedDoubleQLearning/td_error_targ: -0.00844, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.919, SoftPG/grads_max: 0.168, SoftPG/grads_norm: 0.298, SoftPG/kl_div_old: 1.11, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.52
[dsac|TrainMonitor|INFO] ep: 75, T: 15,075, G: -1.8e+03, avg_r: -9.02, avg_G: -1.29e+03, t: 200, dt: 17.804ms, SoftClippedDoubleQLearning/grads_max: 0.397, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.79, SoftClippedDoubleQLearning/td_error: 0.00114, SoftClippedDoubleQLearning/td_error_targ: -0.0144, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.934, SoftPG/grads_max: 0.177, SoftPG/grads_norm: 0.306, SoftPG/kl_div_old: 1.16, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 76, T: 15,276, G: -1.37e+03, avg_r: -6.83, avg_G: -1.29e+03, t: 200, dt: 17.985ms, SoftClippedDoubleQLearning/grads_max: 0.393, SoftClippedDoubleQLearning/grads_norm: 1.02, SoftClippedDoubleQLearning/loss: 0.783, SoftClippedDoubleQLearning/td_error: -0.00264, SoftClippedDoubleQLearning/td_error_targ: -0.0129, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 0.956, SoftPG/grads_max: 0.169, SoftPG/grads_norm: 0.31, SoftPG/kl_div_old: 1.16, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.53
[dsac|TrainMonitor|INFO] ep: 77, T: 15,477, G: -743, avg_r: -3.72, avg_G: -1.24e+03, t: 200, dt: 17.934ms, SoftClippedDoubleQLearning/grads_max: 0.399, SoftClippedDoubleQLearning/grads_norm: 1.04, SoftClippedDoubleQLearning/loss: 0.794, SoftClippedDoubleQLearning/td_error: 0.00778, SoftClippedDoubleQLearning/td_error_targ: -0.0117, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.01, SoftPG/grads_max: 0.166, SoftPG/grads_norm: 0.296, SoftPG/kl_div_old: 1.22, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.54
[dsac|TrainMonitor|INFO] ep: 78, T: 15,678, G: -1.37e+03, avg_r: -6.83, avg_G: -1.25e+03, t: 200, dt: 17.847ms, SoftClippedDoubleQLearning/grads_max: 0.406, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.794, SoftClippedDoubleQLearning/td_error: -0.000648, SoftClippedDoubleQLearning/td_error_targ: 0.00351, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.03, SoftPG/grads_max: 0.167, SoftPG/grads_norm: 0.308, SoftPG/kl_div_old: 1.24, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.54
[dsac|TrainMonitor|INFO] ep: 79, T: 15,879, G: -1.48e+03, avg_r: -7.39, avg_G: -1.27e+03, t: 200, dt: 17.797ms, SoftClippedDoubleQLearning/grads_max: 0.398, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.788, SoftClippedDoubleQLearning/td_error: -0.00115, SoftClippedDoubleQLearning/td_error_targ: -0.00406, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.02, SoftPG/grads_max: 0.141, SoftPG/grads_norm: 0.252, SoftPG/kl_div_old: 1.25, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.55
[dsac|TrainMonitor|INFO] ep: 80, T: 16,080, G: -1.67e+03, avg_r: -8.35, avg_G: -1.31e+03, t: 200, dt: 17.926ms, SoftClippedDoubleQLearning/grads_max: 0.399, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.777, SoftClippedDoubleQLearning/td_error: -0.00173, SoftClippedDoubleQLearning/td_error_targ: -0.00815, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.04, SoftPG/grads_max: 0.154, SoftPG/grads_norm: 0.274, SoftPG/kl_div_old: 1.27, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.55
[dsac|TrainMonitor|INFO] ep: 81, T: 16,281, G: -860, avg_r: -4.3, avg_G: -1.27e+03, t: 200, dt: 17.890ms, SoftClippedDoubleQLearning/grads_max: 0.406, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.789, SoftClippedDoubleQLearning/td_error: 0.00353, SoftClippedDoubleQLearning/td_error_targ: -0.00365, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.05, SoftPG/grads_max: 0.169, SoftPG/grads_norm: 0.305, SoftPG/kl_div_old: 1.27, SoftPG/loss: 9.34, SoftPG/loss_bare: 9.55
[dsac|TrainMonitor|INFO] ep: 82, T: 16,482, G: -1.07e+03, avg_r: -5.35, avg_G: -1.25e+03, t: 200, dt: 17.795ms, SoftClippedDoubleQLearning/grads_max: 0.41, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.778, SoftClippedDoubleQLearning/td_error: -0.00925, SoftClippedDoubleQLearning/td_error_targ: -0.00354, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.09, SoftPG/grads_max: 0.16, SoftPG/grads_norm: 0.282, SoftPG/kl_div_old: 1.32, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.56
[dsac|TrainMonitor|INFO] ep: 83, T: 16,683, G: -962, avg_r: -4.81, avg_G: -1.22e+03, t: 200, dt: 17.747ms, SoftClippedDoubleQLearning/grads_max: 0.409, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.786, SoftClippedDoubleQLearning/td_error: -0.00545, SoftClippedDoubleQLearning/td_error_targ: -0.00268, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.11, SoftPG/grads_max: 0.142, SoftPG/grads_norm: 0.259, SoftPG/kl_div_old: 1.36, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.57
[dsac|TrainMonitor|INFO] ep: 84, T: 16,884, G: -1.56e+03, avg_r: -7.79, avg_G: -1.25e+03, t: 200, dt: 17.797ms, SoftClippedDoubleQLearning/grads_max: 0.407, SoftClippedDoubleQLearning/grads_norm: 1.05, SoftClippedDoubleQLearning/loss: 0.78, SoftClippedDoubleQLearning/td_error: 0.00423, SoftClippedDoubleQLearning/td_error_targ: -0.00923, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.13, SoftPG/grads_max: 0.18, SoftPG/grads_norm: 0.324, SoftPG/kl_div_old: 1.36, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.57
[dsac|TrainMonitor|INFO] ep: 85, T: 17,085, G: -1.3e+03, avg_r: -6.5, avg_G: -1.26e+03, t: 200, dt: 17.771ms, SoftClippedDoubleQLearning/grads_max: 0.4, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.765, SoftClippedDoubleQLearning/td_error: -9.46e-05, SoftClippedDoubleQLearning/td_error_targ: -0.008, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.18, SoftPG/grads_max: 0.151, SoftPG/grads_norm: 0.282, SoftPG/kl_div_old: 1.45, SoftPG/loss: 9.35, SoftPG/loss_bare: 9.58
[dsac|TrainMonitor|INFO] ep: 86, T: 17,286, G: -1.49e+03, avg_r: -7.46, avg_G: -1.28e+03, t: 200, dt: 17.933ms, SoftClippedDoubleQLearning/grads_max: 0.401, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.751, SoftClippedDoubleQLearning/td_error: -0.00512, SoftClippedDoubleQLearning/td_error_targ: -0.00857, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 1.56, SoftPG/grads_max: 0.132, SoftPG/grads_norm: 0.259, SoftPG/kl_div_old: 1.89, SoftPG/loss: 9.33, SoftPG/loss_bare: 9.64
[dsac|TrainMonitor|INFO] ep: 87, T: 17,487, G: -1.28e+03, avg_r: -6.41, avg_G: -1.28e+03, t: 200, dt: 17.907ms, SoftClippedDoubleQLearning/grads_max: 0.398, SoftClippedDoubleQLearning/grads_norm: 1.03, SoftClippedDoubleQLearning/loss: 0.731, SoftClippedDoubleQLearning/td_error: -0.00714, SoftClippedDoubleQLearning/td_error_targ: -0.0154, EntropyRegularizer/beta: 0.04, EntropyRegularizer/entropy: 4.84, SoftPG/grads_max: 0.5, SoftPG/grads_norm: 0.819, SoftPG/kl_div_old: 5.88, SoftPG/loss: 8.85, SoftPG/loss_bare: 9.81
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Traceback (most recent call last):
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 131, in _nan_check_posthook
xla.check_special(xla.xla_call_p, buffers)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
_check_special(name, buf.xla_shape(), buf)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in xla_call
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
out = compiled_fun(*args)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
check_special(name, out_bufs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
_check_special(name, buf.xla_shape(), buf)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in grads_and_metrics_func
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
out = compiled_fun(*args)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
check_special(name, out_bufs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
_check_special(name, buf.xla_shape(), buf)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jvp(log_proba)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
out = compiled_fun(*args)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
check_special(name, out_bufs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
_check_special(name, buf.xla_shape(), buf)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jvp(log_proba)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 584, in _xla_call_impl
out = compiled_fun(*args)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
check_special(name, out_bufs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
_check_special(name, buf.xla_shape(), buf)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in jvp(_einsum)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/schubert/projects/coax/doc/examples/pendulum/dsac.py", line 98, in <module>
File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 159, in update
return super().update(transition_batch, None)
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 164, in update
grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 193, in grads_and_metrics
return super().grads_and_metrics(transition_batch, None)
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 233, in grads_and_metrics
return self._grad_and_metrics_func(
File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
_ = clone.call_wrapped(*args) # probably won't return
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 95, in grads_and_metrics_func
grads_func(params, state, hyperparams, rng, transition_batch, Adv)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 924, in grad_f_aux
(_, aux), g = value_and_grad_f(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 995, in value_and_grad_f
ans, vjp_py, aux = _vjp(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 2317, in _vjp
out_primal, out_vjp, aux = ad.vjp(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 118, in vjp
out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 103, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 68, in loss_func
self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
File "/home/schubert/projects/coax/coax/policy_objectives/_soft_pg.py", line 45, in objective_func
log_pi = self.pi.proba_dist.log_proba(dist_params, A)
File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
_ = clone.call_wrapped(*args) # probably won't return
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/schubert/projects/coax/coax/proba_dists/_composite.py", line 137, in log_proba
return self._structure.log_proba(dist_params, X)
File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
_ = clone.call_wrapped(*args) # probably won't return
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/schubert/projects/coax/coax/proba_dists/_normal.py", line 127, in log_proba
quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 4939, in einsum
return _einsum(operands, contractions, precision)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 416, in cache_miss
out_flat = xla.xla_call(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 600, in _xla_call_impl
_ = clone.call_wrapped(*args) # probably won't return
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 5070, in _einsum
operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 730, in dot_general
return dot_general_p.bind(lhs, rhs,
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 288, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/ad.py", line 465, in standard_jvp
val_out = primitive.bind(*primals, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 151, in process_primitive
return self.default_process_primitive(primitive, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 158, in default_process_primitive
return primitive.bind(*consts, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 313, in apply_primitive
return compiled_fun(*args)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 337, in <lambda>
return lambda *args, **kw: compiled(*args, **kw)[0]
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 978, in _execute_compiled
check_special(name, out_bufs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 378, in check_special
_check_special(name, buf.xla_shape(), buf)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 384, in _check_special
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
jax._src.traceback_util.UnfilteredStackTrace: FloatingPointError: invalid value (nan) encountered in dot_general
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/schubert/projects/coax/doc/examples/pendulum/dsac.py", line 98, in <module>
File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 159, in update
return super().update(transition_batch, None)
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 164, in update
grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
File "/home/schubert/projects/coax/coax/policy_objectives/_deterministic_pg.py", line 193, in grads_and_metrics
return super().grads_and_metrics(transition_batch, None)
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 233, in grads_and_metrics
return self._grad_and_metrics_func(
File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/api.py", line 137, in _nan_check_posthook
fun._cache_miss(*args, **kwargs)[0] # probably won't return
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 95, in grads_and_metrics_func
grads_func(params, state, hyperparams, rng, transition_batch, Adv)
File "/home/schubert/projects/coax/coax/policy_objectives/_base.py", line 68, in loss_func
self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
File "/home/schubert/projects/coax/coax/policy_objectives/_soft_pg.py", line 45, in objective_func
log_pi = self.pi.proba_dist.log_proba(dist_params, A)
File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/schubert/projects/coax/coax/proba_dists/_composite.py", line 137, in log_proba
return self._structure.log_proba(dist_params, X)
File "/home/schubert/projects/coax/coax/utils/_jit.py", line 80, in __call__
return self._jitted_func(*args, **kwargs)
File "/home/schubert/projects/coax/coax/proba_dists/_normal.py", line 127, in log_proba
quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 4939, in einsum
return _einsum(operands, contractions, precision)
File "/home/schubert/miniconda3/tmp/envs/coax/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 5070, in _einsum
operand = lax.dot_general(rhs, lhs, dimension_numbers, precision)
FloatingPointError: invalid value (nan) encountered in dot_general
Looks like this is happening in Normal.log_proba, specifically in:
quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
Might be worth putting a breakpoint there.
Perhaps logvar is large and negative? If that's the case, it might mean that the entropy bonus is taking over the loss. In fact, that seems to be what the logs are telling us too.
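To make the failure mode concrete, here's a toy reproduction (the numbers are made up, not taken from the actual run): a very negative logvar makes exp(-logvar) overflow to inf in float32, and multiplying that by a zero squared residual yields nan inside the einsum contraction.
import jax.numpy as jnp

X = jnp.array([[0.5, 0.0]])
mu = jnp.array([[0.5, 0.0]])            # X - mu == 0 everywhere
logvar = jnp.array([[-200.0, -200.0]])  # exp(-logvar) == exp(200) -> inf in float32

quadratic = jnp.einsum('ij,ij->i', jnp.square(X - mu), jnp.exp(-logvar))
print(quadratic)  # [nan], because 0 * inf is nan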
Perhaps try a smaller entropy regularizer coefficient?
Yes, that definitely helps, but the entropy keeps rising even with a tiny coefficient, and it only happens with the StochasticQ functions. The learning performance of SAC is also not the same as before, so I suspect there are still some differences due to the refactoring.
Ok, I verified the implementation through this test against the old implementation:
def test_update_boxspace(self):
    env = self.env_boxspace
    func_q = self.func_q_type1
    func_pi = self.func_pi_boxspace
    transition_batch = self.transition_boxspace

    q1 = Q(func_q, env)
    q2 = Q(func_q, env)
    pi1 = Policy(func_pi, env)
    pi2 = Policy(func_pi, env)
    q_targ1 = q1.copy()
    q_targ2 = q2.copy()
    updater1 = ClippedDoubleQLearning(
        q1, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))
    updater2 = ClippedDoubleQLearning(
        q2, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1, q_targ2], optimizer=sgd(1.0))

    q1_old = deepcopy(q1)
    q2_old = deepcopy(q2)
    q_targ1_old = q1_old.copy()
    q_targ2_old = q2_old.copy()
    old_updater1 = OldClippedDoubleQLearning(
        q1_old, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1_old, q_targ2_old], optimizer=sgd(1.0))
    old_updater2 = OldClippedDoubleQLearning(
        q2_old, pi_targ_list=[pi1, pi2], q_targ_list=[q_targ1_old, q_targ2_old], optimizer=sgd(1.0))

    params1 = deepcopy(q1.params)
    params2 = deepcopy(q2.params)
    function_state1 = deepcopy(q1.function_state)
    function_state2 = deepcopy(q2.function_state)

    updater1.update(transition_batch)
    updater2.update(transition_batch)
    old_updater1.update(transition_batch)
    old_updater2.update(transition_batch)

    self.assertPytreeNotEqual(params1, q1.params)
    self.assertPytreeNotEqual(params2, q2.params)
    self.assertPytreeNotEqual(function_state1, q1.function_state)
    self.assertPytreeNotEqual(function_state2, q2.function_state)
    self.assertPytreeAlmostEqual(q1_old.params, q1.params)
    self.assertPytreeAlmostEqual(q2_old.params, q2.params)
    self.assertPytreeAlmostEqual(q1_old.function_state, q1.function_state)
    self.assertPytreeAlmostEqual(q2_old.function_state, q2.function_state)
So the problem is with the StochasticQ function, or rather with finding the right hyperparameters. Do we want to defer this until after this change has been merged? I can remove the dsac.py file as long as there is no working implementation.
I think the tree structure of in/out_axes must match that of the input (up to the level you provide). So you'll probably have to transform them using hk.data_structures.to_haiku_dict() and include all keys, i.e. also 'q'.
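For reference, a toy sketch of that prefix-matching behaviour using plain dicts (the shapes here are made up for illustration, this is not coax code): in_axes may be a pytree prefix of the argument, with one axis spec per key.
import jax
import jax.numpy as jnp

def g(d):
    return d['q'] + d['q_targ'].sum(axis=0)

# in_axes mirrors the top level of the input dict: one axis spec per key
vg = jax.vmap(g, in_axes=({'q': 0, 'q_targ': 1},))
vg({'q': jnp.ones((4, 3)), 'q_targ': jnp.ones((2, 4, 3))})  # shape: (4, 3)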
OK, so I did try to implement the class using haiku's to_immutable_dict method, but that won't work. The way vmap treats None values in the in_axes specification is by replacing them with a special placeholder object. This substitution operates on the leaves of the FlatMapping, which are computed up front using jax.tree_flatten, and that method discards leaves whose value is None. Here is a quick demonstration of what works:
import jax
import jax.numpy as jnp
import haiku as hk

def fun(d):
    return d['a'] + d['b']

vfun = jax.vmap(fun, in_axes=(hk.data_structures.to_immutable_dict({'a': 0, 'b': 0}),))
vfun(hk.data_structures.to_immutable_dict({'a': jnp.array([1, 2]), 'b': jnp.array([0, 3])}))
and what does not work due to the None values:
vfun = jax.vmap(fun, in_axes=(hk.data_structures.to_immutable_dict({'a': 0, 'b': None}),))
vfun(hk.data_structures.to_immutable_dict({'a': jnp.array([1, 2]), 'b': jnp.array([0, 3])}))  # raises an error: the None leaf was discarded
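The root cause can be shown directly; this is a minimal sketch of the tree_flatten behaviour described above:
import jax

# jax.tree_flatten treats None as an empty subtree, so the leaf simply
# disappears -- which is why the None axis spec above no longer lines up
# with the FlatMapping's precomputed leaves.
leaves, treedef = jax.tree_flatten({'a': 0, 'b': None})
print(leaves)  # [0]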
So I have to convert the parameters and state with to_mutable_dict before vmapping over them. Does this have any unwanted consequences, @KristianHolsheimer?
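For concreteness, a minimal sketch of the workaround idea with plain dicts (same toy fun as above, values made up):
# With a plain dict the None entry survives flattening, so 'b' is
# broadcast rather than mapped:
vfun = jax.vmap(fun, in_axes=({'a': 0, 'b': None},))
vfun({'a': jnp.array([1, 2]), 'b': jnp.array(5)})  # [6, 7]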
Hmmm, the deeper we get into it, the hackier it starts to feel. Also, we may expect a slight performance drop (more memory, typically) from using vmap.
From the perspective of the jitted function, the double for loop is actually fine. The main problem is that ClippedDoubleQLearning doesn't make use of the same class inheritance as the other value-function updaters, and that it looks too convoluted and ugly. But so far the vmap implementation, although an elegant idea, feels equally convoluted.
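For what it's worth, here's a toy sketch of the double-loop structure being referred to (stand-in functions with made-up signatures, not coax's actual API): one candidate target per (pi_targ, q_targ) pair, then an elementwise minimum.
import jax.numpy as jnp

# Hypothetical stand-ins, just to show the shape of the computation:
def make_q(w):
    return lambda s, a: w * s * a

def make_pi(b):
    return lambda s: s + b

q_targ_list = [make_q(1.0), make_q(2.0)]
pi_targ_list = [make_pi(0.0), make_pi(0.5)]

s_next = jnp.array([0.1, 0.2, 0.3])
targets = jnp.stack([
    q(s_next, pi(s_next))
    for pi in pi_targ_list
    for q in q_targ_list
])                                 # shape: (len(pi_targ_list) * len(q_targ_list), batch)
target = jnp.min(targets, axis=0)  # pessimistic target per transition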
Yes, I agree. I think it might be better to just add the stochastic q-function parts to ClippedDoubleQLearning and accept its current form.
This PR refactors ClippedDoubleQLearning to use the functions of DoubleQLearning, in preparation for Distributional SAC (DSAC).