EdanToledo / Stoix

🏛️A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX • End-to-End JAX RL
Apache License 2.0

[FEATURE] Allow CNN + MLP torsos before the action heads #109

Closed roger-creus closed 2 months ago

roger-creus commented 2 months ago

Currently, several CNN-based torsos (e.g. CNN and VisualResnet) do not allow for an MLP layer between the convolutional feature outputs and the action head (either CategoricalHead or DiscreteQNetworkHead). I think this is important for reproducing the widely used architectures that consist of 1) a convolutional feature extractor; 2) an MLP projector; and finally 3) the action head.
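The three-stage architecture described above can be sketched in plain JAX. This is a minimal illustration, not Stoix's actual implementation: the function and parameter names (`cnn_torso`, `mlp_torso`, `categorical_head`) are hypothetical, and the shapes are toy-sized.

```python
import jax
import jax.numpy as jnp

def conv(x, w):
    # NHWC input, HWIO kernel, VALID padding
    return jax.lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

def cnn_torso(params, x):
    # 1) convolutional feature extractor, flattened per batch element
    x = jax.nn.relu(conv(x, params["conv_w"]))
    return x.reshape(x.shape[0], -1)

def mlp_torso(params, x):
    # 2) MLP projector between the conv features and the head
    return jax.nn.relu(x @ params["mlp_w"] + params["mlp_b"])

def categorical_head(params, x):
    # 3) action head producing logits
    return x @ params["head_w"]

def network(params, obs):
    return categorical_head(params, mlp_torso(params, cnn_torso(params, obs)))

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
params = {
    "conv_w": jax.random.normal(k1, (3, 3, 1, 8)) * 0.1,   # 3x3 conv, 8 channels
    "mlp_w": jax.random.normal(k2, (8 * 6 * 6, 64)) * 0.1,  # 8x8 VALID conv -> 6x6x8
    "mlp_b": jnp.zeros(64),
    "head_w": jax.random.normal(k3, (64, 4)) * 0.1,         # 4 actions
}
logits = network(params, jnp.ones((2, 8, 8, 1)))
print(logits.shape)  # (2, 4)
```

The point of the issue is that, without the middle `mlp_torso` step, the conv features would feed the head directly.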

It would be great to be able to combine torsos in the config files (I am not 100% sure whether this is already possible?).

An alternative solution would be to modify the action heads (e.g. CategoricalHead, DiscreteQNetworkHead) to take layer_sizes as an optional input and prepend those layers before the output layer.
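That alternative might look something like the following sketch, where a head optionally builds hidden layers in front of its output layer. The `init_head`/`head_apply` names and the parameter layout are hypothetical, purely for illustration.

```python
import jax
import jax.numpy as jnp

def init_head(key, in_dim, layer_sizes, num_actions):
    # hypothetical: hidden layers from `layer_sizes` are prepended
    # before the final output layer
    dims = [in_dim, *layer_sizes, num_actions]
    keys = jax.random.split(key, len(dims) - 1)
    return [(jax.random.normal(k, (a, b)) * 0.1, jnp.zeros(b))
            for k, (a, b) in zip(keys, zip(dims[:-1], dims[1:]))]

def head_apply(params, x):
    *hidden, (w_out, b_out) = params
    for w, b in hidden:            # optional pre-output MLP layers
        x = jax.nn.relu(x @ w + b)
    return x @ w_out + b_out       # logits

key = jax.random.PRNGKey(0)
params = init_head(key, 288, [64], 4)   # one hidden layer of 64 units
logits = head_apply(params, jnp.ones((2, 288)))
print(logits.shape)  # (2, 4)
```

With `layer_sizes=[]` this degenerates to the current behaviour of a single linear output layer.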

Edit: I realized VisualResNet actually has this, so it's only missing for the standard CNN

EdanToledo commented 2 months ago

Hey, you are 100% right, and this is a feature that I've been meaning to add (I just forgot about it ... ). It's quite a small change, so I'll try to get around to it sometime on Sunday. Thanks for reminding me :)

EdanToledo commented 2 months ago

So basically, I've been thinking about it, and I actually want to make a slightly bigger change to how networks are constructed - it would be nice to be able to easily chain together any and all torsos, but I'm not exactly sure of the best way to do it. I have no idea, though, when I will get around to this. For now, I've simply added the MLP torso to the end of the CNN torso and put a hidden_sizes argument in the config, like that of the visual resnet. I hope this is okay for now. If you want a more complicated torso, I'd recommend simply copying and pasting the most relevant one, editing it there, and making a new config for it. I'll merge the PR when it passes the tests.
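The "chain together any and all torsos" idea mentioned above could, in principle, be as simple as a composition combinator, since each torso is just a callable from features to features. A minimal sketch (the `chain` helper is hypothetical, not part of Stoix):

```python
import jax.numpy as jnp

def chain(*torsos):
    # compose torsos left to right: the output of one feeds the next
    def chained(x):
        for torso in torsos:
            x = torso(x)
        return x
    return chained

# toy stand-ins for real torsos
flatten = lambda x: x.reshape(x.shape[0], -1)
scale = lambda x: x * 2.0

torso = chain(flatten, scale)
out = torso(jnp.ones((2, 3, 3)))
print(out.shape)  # (2, 9)
```

A config-driven version would then only need to instantiate each torso from its config entry and fold them with something like `chain`.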