mobeets / q-rnn

move to policy gradient approach? #22

Open mobeets opened 6 months ago

mobeets commented 6 months ago

Remember that one benefit of policy gradient over Q-learning is that it can learn a stochastic policy directly, so we don't have to fine-tune the exploration during training.

note that if we use softmax exploration with Q-learning, we're basically ignoring that exploration when we train, because we're updating Q(s,a) and not the action probabilities derived from Q(s,a). [also, instead of taking max_a Q(s(t+1), a) in the target, we should probably be using the softmax here as well.]
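a minimal sketch of what the bracketed suggestion could look like, assuming "using the softmax" means bootstrapping from the softmax-weighted expectation of the next-step Q values rather than from their max. the function names, `gamma`, and `temperature` are all made up for illustration and are not from the repo:

```python
import math

def softmax(q_values, temperature=1.0):
    # subtract the max before exponentiating for numerical stability
    m = max(q_values)
    exps = [math.exp((q - m) / temperature) for q in q_values]
    total = sum(exps)
    return [e / total for e in exps]

def max_target(reward, next_q, gamma=0.9):
    # standard Q-learning target: bootstrap from the greedy action
    return reward + gamma * max(next_q)

def softmax_target(reward, next_q, gamma=0.9, temperature=1.0):
    # assumed alternative: bootstrap from the expected Q value under
    # the same softmax policy that is used for exploration
    probs = softmax(next_q, temperature)
    expected_q = sum(p * q for p, q in zip(probs, next_q))
    return reward + gamma * expected_q
```

as the temperature goes to zero the softmax target approaches the max target, so the max version is the limiting case of this one.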

so i feel like it could be good to try A3C with an RNN, which rllib supports. the question is whether adding a KL penalty would be straightforward or not. see code here. the tricky part is how we would pass in the current marginal policy, since this is part of the loss. (i think we would add something to postprocess_trajectory, in the code file above)

edit: note that postprocess_trajectory is only applied at the end of the episode, and I think the loss function may be as well. so we would need to calculate a time-varying marginal policy in postprocess_trajectory.

mobeets commented 6 months ago

though of course, we would still have the problem that we probably don't want the KL penalty applied when taking a null action (see #20)

mobeets commented 6 months ago

SAC-D

i read this one post about how in RL, it's often better/easier to have a single-file implementation of a given algorithm rather than using a full-on generalized library, because with the library it's very hard to understand how the algorithm works! i think this is a great point, and especially relevant given that we want to make a few atypical changes:

anyway, one option is to use pomdp-baselines, the repo from that Ni/Eysenbach paper on using RNNs in POMDPs. they implement TD3 (which learns a deterministic policy, so we don't want that) and SAC. SAC was originally formulated for continuous action spaces only, but there is a discrete-action variant, SAC-D, which is what we want (and they have that implemented as well).

overall, it would take some work, but this seems like a reasonable repo to start with. check their example.ipynb file for reference.

and checking out the paper, I think it really is pretty straightforward: We just replace every occurrence of log π(s(t)) with log π(s(t)) - log P, where P is the marginal policy vector, updated with a running average.
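a rough sketch of that substitution, treating both the policy and the marginal as probability vectors over a discrete action set. the helper names, the smoothing `rate`, and the `eps` guard are hypothetical and not taken from pomdp-baselines:

```python
import math

def update_marginal(marginal, policy, rate=0.05):
    # running (exponentially weighted) average of the policy vector
    return [(1 - rate) * m + rate * p for m, p in zip(marginal, policy)]

def penalized_log_probs(policy, marginal, eps=1e-8):
    # the proposed replacement term, per action: log pi - log P
    return [math.log(p + eps) - math.log(m + eps)
            for p, m in zip(policy, marginal)]

def kl_to_marginal(policy, marginal, eps=1e-8):
    # expectation of (log pi - log P) under pi, i.e. KL(pi || P);
    # this is what stands in for the expected log pi (negative entropy)
    diffs = penalized_log_probs(policy, marginal, eps)
    return sum(p * d for p, d in zip(policy, diffs))
```

in SAC-D the entropy term is an expectation of log π over actions, so if the substitution is applied there, the entropy bonus effectively becomes a penalty on KL(π || P). that is my reading of the change, not something checked against their code.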

Soft Q-learning

not clear to me how these methods differ from one another, but: https://arxiv.org/pdf/1512.08562.pdf http://proceedings.mlr.press/v70/haarnoja17a/haarnoja17a.pdf

sam says lucy has been using the method from the fox paper

A2C

alternatively, we could use A2C, which is often implemented with an LSTM. this is what seems to be the most-used implementation.

or this repo, which has the benefit that it only implements A2C and nothing else.

marginal policy

one thought is that, in actor-critic methods, part of computing the loss involves getting the action_log_probs across the whole episode. those could be directly accumulated (with exponential smoothing, after exponentiating so that we average probabilities rather than log-probabilities) to get the marginal policy at each time step.
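here is one way the accumulation could go, as a sketch only. it assumes `action_log_probs` holds the full log-probability vector over actions at each time step (many A2C implementations only store the log-prob of the action actually taken, which would not be enough). `running_marginal`, `rate`, and `initial` are hypothetical names:

```python
import math

def running_marginal(action_log_probs, rate=0.1, initial=None):
    """Return the marginal policy available at each time step of one episode.

    action_log_probs: list of per-step lists of log-probabilities,
    one entry per action (assumed layout, see note above).
    """
    n_actions = len(action_log_probs[0])
    if initial is None:
        # assumption: start each episode from a uniform marginal
        marginal = [1.0 / n_actions] * n_actions
    else:
        marginal = list(initial)
    marginals = []
    for log_probs in action_log_probs:
        # record the marginal as it stood *before* seeing step t
        marginals.append(list(marginal))
        # smooth in probability space, not log space
        probs = [math.exp(lp) for lp in log_probs]
        marginal = [(1 - rate) * m + rate * p
                    for m, p in zip(marginal, probs)]
    return marginals
```

whether step t should be compared against the marginal before or after its own update is a design choice; this version uses the one before, so the penalty at step t does not depend on the policy output at step t through the marginal.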

one question though is whether we are supposed to keep the gradients in there? e.g., is our penalty just an intrinsic reward, detached, or should we be able to backprop and see how that intrinsic reward relates to the entropy of our outputted q values?