mobeets / q-rnn


delayed stateless cartpole #15

Closed · mobeets closed this issue 1 year ago

mobeets commented 1 year ago

Trained for 200 epochs, as in R2D2, on the stateless cartpole but with a fixed observation delay, then evaluated for 100 episodes at each of the delays below.

Below are models trained on delay = 0, 1, and 2:

Curiously, for delays 1 and 2, the model does better when tested on a different (but lower) delay than it was trained on!
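
For reference, here's a minimal sketch of the kind of observation-delay wrapper assumed here (gymnasium API; class and argument names are illustrative, not the repo's actual code):

```python
# Sketch only: delay observations by a fixed number of steps using a FIFO buffer.
from collections import deque

import gymnasium as gym


class DelayedObservation(gym.Wrapper):
    """Return the observation from `delay` steps ago, padding with the initial obs."""

    def __init__(self, env, delay: int):
        super().__init__(env)
        self.delay = delay
        self.buffer = deque()

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        # Pre-fill so the first `delay` steps re-emit the initial observation.
        self.buffer = deque([obs] * self.delay)
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.buffer.append(obs)
        # The left end of the buffer holds the observation from `delay` steps ago.
        return self.buffer.popleft(), reward, terminated, truncated, info


# e.g. env = DelayedObservation(gym.make("CartPole-v1"), delay=2)
```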

mobeets commented 1 year ago

Note: the time step discretization of cartpole is 0.02 s, or 20 ms. So a delay of 5 steps (100 ms) would probably be a reasonable delay for human behavior.

mobeets commented 1 year ago

After training on delay=2 (left) or delay=4 (right), but now including the previous action as an input, we get better behavior at the trained delay than at zero delay:

(Note: 500 is the max episode length.)

One caution though, which I just realized: the previous-action input does not have any sensory delay. This might be biologically reasonable given an efference copy, but it's something to consider; see the sketch below.
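
For concreteness, a sketch of how the previous-action input might be appended on top of the delayed observations (illustrative names, not the repo's code). The action comes straight from the agent's own last command, so it bypasses the sensory delay (an efference-copy-like signal):

```python
# Sketch only: append a one-hot of the previous action to the (possibly delayed) observation.
import gymnasium as gym
import numpy as np


class PreviousActionWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.n_actions = env.action_space.n
        low = np.concatenate([env.observation_space.low, np.zeros(self.n_actions)]).astype(np.float32)
        high = np.concatenate([env.observation_space.high, np.ones(self.n_actions)]).astype(np.float32)
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def _augment(self, obs, action):
        onehot = np.zeros(self.n_actions, dtype=np.float32)
        if action is not None:
            onehot[action] = 1.0
        return np.concatenate([obs, onehot])

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return self._augment(obs, None), info  # no previous action yet

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # `obs` may already be delayed (e.g. by an observation-delay wrapper),
        # but the appended action is the one just taken, i.e. undelayed.
        return self._augment(obs, action), reward, terminated, truncated, info


# e.g. env = PreviousActionWrapper(DelayedObservation(gym.make("CartPole-v1"), delay=2))
```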

mobeets commented 1 year ago

Given a model trained with some delay, I then build a regression model using the LSTM's hidden or cell activity, Z, to predict either the observations, Y_obs, or the true (hidden/future) state, Y_state. On cross-validated trials, a delay=2 model did appear to show a stronger correlation with the true future states than with the actual (delayed) observations. This suggests the model really is learning to predict the future state!
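
A minimal sketch of this kind of cross-validated decoding, assuming Z is an (n_timesteps x n_units) array of LSTM hidden or cell activity collected during rollouts and Y is the matching observations or true states (ridge regression here is my choice for the sketch; names are illustrative):

```python
# Sketch only: cross-validated R^2 of a linear readout from LSTM activity Z to targets Y.
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score


def decoding_r2(Z, Y, n_splits=5):
    """Mean cross-validated R^2 of a linear (ridge) readout from Z to Y."""
    scores = cross_val_score(Ridge(alpha=1.0), Z, Y, cv=n_splits, scoring="r2")
    return scores.mean()


# r2_obs = decoding_r2(Z_hidden, Y_obs)      # how well Z encodes the delayed observation
# r2_state = decoding_r2(Z_hidden, Y_state)  # how well Z encodes the true current state
```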

Z_hidden->Y_obs, R^2: 0.943
Z_hidden->Y_state, R^2: 0.997
Z_cell->Y_obs, R^2: 0.943
Z_cell->Y_state, R^2: 0.998

Got the same result for delay=3:

NEVERMIND! This must be a conceptual bug, because even an untrained model shows the same results.

mobeets commented 1 year ago

After fixing the bug, here's a model trained on delay=3:

Compare to delay=0: above, the middle panel uses an env with delay=0 for rollouts, and the right panel uses delay=3. The distinction matters because in the delay=0 env the model's control is probably really good, meaning the state is very stable, meaning you can easily predict states into the future (because you predict it's the same as it is now!). So we need to match the rollout env to the one used for the delay=3 model. In that case, the delay=3 model has higher $R^2$, but it also has higher performance, so what we should really do is use a random policy for rollouts.

Repeating again using a random policy in a delay=0 env (left) and delay=3 env (middle), for the model trained on delay=0 (top) and delay=3 (bottom):

^ Also including the performance of each model when used as the policy, since both were retrained since the previous panels.
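
A sketch of how the random-policy rollouts might be collected, so the state distribution isn't shaped by the trained policy (the `env.unwrapped.state` access assumes CartPole; names are illustrative):

```python
# Sketch only: roll out a uniform-random policy and record delayed observations,
# true underlying states, and actions for the decoding analysis.
import numpy as np


def collect_random_rollouts(env, n_episodes=100):
    observations, true_states, actions = [], [], []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            action = env.action_space.sample()  # random policy, independent of any trained model
            observations.append(obs)                           # (possibly delayed) observation
            true_states.append(np.array(env.unwrapped.state))  # undelayed CartPole state
            actions.append(action)
            obs, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
    return np.array(observations), np.array(true_states), np.array(actions)
```

The trained model is then run on these observation sequences only to produce Z, without controlling the env, so the comparison of $R^2$ values isn't confounded by how well each policy stabilizes the state.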

mobeets commented 1 year ago

I think this task is not the right one to use for this question, because our policy changes the state distribution! The better the policy controls the state, the more predictable the state will be in the future.

So I think my "catching" task will be better.