araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)

[Question] Extending sbx algorithms (e.g. via a callback) #15

Closed: asmith26 closed this issue 1 year ago

asmith26 commented 1 year ago

Hi there,

I'm trying to experiment with "RL while learning Minmax penalty" (paper, code), and I thought I'd try adding it to an sbx DroQ setup. From the paper, the implementation looks quite straightforward, essentially:

for each step:
    # update the penalty estimate from the reward and the critic's Q-value for the current state
    penalty = minmaxpenalty.update(reward, Q[state])
    if info["unsafe"]:
        reward = penalty

hence I need to obtain the Q-value. I've been looking into the DroQ code, and I believe the Q-value is computed at (?) https://github.com/araffin/sbx/blob/b8dbac11669332c8f8ad9846acb1b6e8bfcd7460/sbx/tqc/tqc.py#L282. I've also been looking into implementing this via a Stable Baselines callback, but I can't seem to get it to work (I'm not sure if this is a suitable use case?).
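
For reference, here is roughly the kind of callback I have been trying. The attribute names (`qf_state`, `apply_fn`) and the trick of mutating `rewards` in place are just my reading of the sbx SAC / SB3 code, so that may well be the part that is wrong:

```python
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback


class MinMaxPenaltyCallback(BaseCallback):
    """Sketch: query the critic at every env step and overwrite unsafe rewards.

    Assumes an sbx SAC-style policy exposing `qf_state` (a flax TrainState whose
    `apply_fn` maps (params, obs, actions) -> Q-values) and a `minmax_penalty`
    object implementing the paper's `update()`.
    """

    def __init__(self, minmax_penalty, verbose: int = 0):
        super().__init__(verbose)
        self.minmax_penalty = minmax_penalty

    def _on_step(self) -> bool:
        # locals injected by SB3's collect_rollouts() loop
        actions = self.locals["actions"]
        rewards = self.locals["rewards"]
        infos = self.locals["infos"]
        # observation before the step (not yet overwritten at this point, I think)
        obs = self.model._last_obs

        qf_state = self.model.policy.qf_state
        # vectorized critic -> shape (n_critics, batch_size, 1) (my guess);
        # with dropout enabled you probably also need to pass rngs={"dropout": ...}
        q_values = np.asarray(qf_state.apply_fn(qf_state.params, obs, actions))
        q_value = float(q_values.mean())

        penalty = self.minmax_penalty.update(float(rewards.mean()), q_value)
        for env_idx, info in enumerate(infos):
            if info.get("unsafe", False):
                # mutate in place so the stored transition (hopefully) sees the penalty
                rewards[env_idx] = penalty
        return True
```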

Many thanks for any help, and for this fantastic lib! :)


araffin commented 1 year ago

hence I need to obtain the Q-value. I've been looking into the DroQ code and I believe the Q-value is computed at (?)

You should probably take a look at SAC first: this DroQ implementation is based on TQC (with quantiles), which makes obtaining the Q-value slightly more complicated (you just need to take a mean along the correct axis). SAC in SBX can also be used in the DroQ configuration (you need to activate dropout + layer norm and add a policy delay with several gradient steps), for instance:
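
```python
from sbx import SAC

# DroQ ~= SAC with dropout + layer norm on the critic, a high number of gradient
# steps per env step, and delayed policy updates.
# Untested here, and keyword names may differ depending on the sbx version,
# so double check against the README.
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
    policy_delay=20,
    gradient_steps=20,
    learning_starts=5_000,
    verbose=1,
)
model.learn(total_timesteps=20_000)
```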

From the paper, the implementation looks quite straightforward, essentially:

Is that during data collection only? Otherwise, the easiest would probably be to fork the repo.

asmith26 commented 1 year ago

Thanks for the tip regarding SAC first, good idea.

is that during data collection only?

From the paper: "We propose a simple model-free algorithm for estimating this penalty online, which can be integrated into any RL pipeline that learns value functions."

Many thanks again for your help!