araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

Allow to pass custom activation function in `policy_kwargs` #41

Closed paolodelia99 closed 6 months ago

paolodelia99 commented 6 months ago

Added the possibility to pass a custom activation function through the `policy_kwargs` argument when creating the following models: TD3, SAC, DDPG and DQN (as in Stable-Baselines3).
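To illustrate the pattern, here is a minimal, dependency-free sketch of how an algorithm class can forward `policy_kwargs` to its policy so that the activation function becomes configurable. `TinyPolicy`, `TinyAlgorithm`, and the scalar "layers" are hypothetical stand-ins, not sbx code. The keyword name `activation_fn` follows the Stable-Baselines3 convention; I am assuming sbx uses the same name.

```python
import math
from typing import Any, Callable, Dict, List, Optional

Activation = Callable[[float], float]


def relu(x: float) -> float:
    """Default activation used when the caller does not override it."""
    return max(0.0, x)


class TinyPolicy:
    """Hypothetical stand-in for a policy class (not the sbx implementation)."""

    def __init__(self, weights: List[float], activation_fn: Activation = relu) -> None:
        self.weights = weights
        self.activation_fn = activation_fn

    def forward(self, x: float) -> float:
        # One scalar "layer" per weight: scale the input, then apply the activation.
        for w in self.weights:
            x = self.activation_fn(w * x)
        return x


class TinyAlgorithm:
    """Hypothetical stand-in for an algorithm class such as SAC or TD3."""

    def __init__(self, policy_kwargs: Optional[Dict[str, Any]] = None) -> None:
        # Forward user-supplied options straight to the policy constructor.
        self.policy = TinyPolicy(weights=[1.0, -1.0], **(policy_kwargs or {}))


# Default behaviour: the policy falls back to ReLU.
default_model = TinyAlgorithm()

# Custom behaviour: override the activation through policy_kwargs.
tanh_model = TinyAlgorithm(policy_kwargs=dict(activation_fn=math.tanh))
```

With the real library, the equivalent call would presumably pass a JAX-compatible callable in the same way, e.g. `policy_kwargs=dict(activation_fn=...)`, when constructing the model.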

Description

Taking inspiration from Stable-Baselines3, I moved the common critic code into the `sbx\common\policy.py` file, since the critic code in the sac module and in the td3 module was the same.
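The refactoring idea can be sketched roughly as follows: one critic definition lives in a common place, and each algorithm-specific policy builds it with whatever activation it was given. All class names and the scalar arithmetic are hypothetical and only illustrate the deduplication; the actual critic in the PR is a neural-network module.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


def relu(x: float) -> float:
    return max(0.0, x)


@dataclass(frozen=True)
class SharedCritic:
    """Hypothetical shared critic, defined once and reused by several policies."""

    layer_scales: Sequence[float]
    activation_fn: Callable[[float], float] = relu

    def q_value(self, obs: float, action: float) -> float:
        # Stand-in for concatenating observation and action before the MLP.
        x = obs + action
        for scale in self.layer_scales:
            x = self.activation_fn(scale * x)
        return x


class SacLikePolicy:
    """Hypothetical SAC-style policy that reuses the shared critic."""

    def __init__(self, activation_fn: Callable[[float], float] = relu) -> None:
        self.critic = SharedCritic(layer_scales=(0.5, 2.0), activation_fn=activation_fn)


class Td3LikePolicy:
    """Hypothetical TD3-style policy that reuses the same critic definition."""

    def __init__(self, activation_fn: Callable[[float], float] = relu) -> None:
        self.critic = SharedCritic(layer_scales=(0.5, 2.0), activation_fn=activation_fn)


sac_policy = SacLikePolicy()
td3_policy = Td3LikePolicy()
```

Because both policies construct the same class, a change to the critic (such as making the activation configurable) only has to be made in one place.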

I also made minor changes to `.gitignore` and the `Makefile`.

Motivation and Context

closes #37

Types of changes

Checklist:

Note: You can run most of the checks using `make commit-checks`.

Note: we are using a maximum length of 127 characters per line

paolodelia99 commented 6 months ago

> quick question: why not PPO too?

You can already pass the activation function in PPO; I just kept things as they are in Stable-Baselines3.