tensorforce / tensorforce

Tensorforce: a TensorFlow library for applied reinforcement learning
Apache License 2.0

Allow training the network with an additional state-prediction goal #559

Closed: jerabaul29 closed this issue 5 years ago

jerabaul29 commented 5 years ago

If I understand correctly, training the network to predict not only the best action or policy but also the time evolution of the system helps:

https://worldmodels.github.io/

http://www0.cs.ucl.ac.uk/staff/d.silver/web/Teaching_files/dyna.pdf

In what follows I have PPO in mind, but I assume other methods would work the same way.

While implementing the full world model / dreaming approach would probably be a lot of work, would it be easier, at least as a first step, to add a flag that trains the network to perform the policy update while simultaneously minimizing a loss that predicts the next state through an additional, parallel output? This would simply mean extending the network's output by a deterministic component the size of the state, and adding a corresponding term to the training loss.

What I mean is to do something like:

output of network = {policy description} + {predicted next state}

where + is a simple concatenation

and the loss would be:

total loss = PPO loss {policy} + L1 {predicted next state - observed next state}
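As a minimal sketch of the proposed output layout (plain Python, not Tensorforce code), assuming a discrete action space so that the policy description is a vector of `num_actions` logits, and assuming `features` is the output of the shared hidden layers; the class and parameter names are hypothetical:

```python
import random
from typing import List, Tuple


class PolicyWithStatePrediction:
    """Toy linear output layer whose output is [policy logits | predicted next state]."""

    def __init__(self, hidden_size: int, num_actions: int, state_size: int, seed: int = 0):
        rng = random.Random(seed)
        self.num_actions = num_actions
        self.state_size = state_size
        # Output size is extended by the size of the state.
        out_size = num_actions + state_size
        # One weight row per output unit, small random initialization.
        self.weights = [[rng.uniform(-0.1, 0.1) for _ in range(hidden_size)]
                        for _ in range(out_size)]
        self.bias = [0.0] * out_size

    def forward(self, features: List[float]) -> Tuple[List[float], List[float]]:
        # y = W x + b over the whole concatenated output.
        output = [sum(w * x for w, x in zip(row, features)) + b
                  for row, b in zip(self.weights, self.bias)]
        # Split the concatenation back into its two parts.
        policy_logits = output[: self.num_actions]
        predicted_next_state = output[self.num_actions:]
        return policy_logits, predicted_next_state
```

Because both parts are computed from the same `features`, gradients from the next-state error would also shape the shared layers, which is the point of adding the auxiliary task.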

I expect there will be a couple of additional hyperparameters (such as the relative weighting of the two losses and the learning rate for the new auxiliary task).
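The proposed total loss could be sketched as below, assuming the standard PPO clipped surrogate with clipping parameter `epsilon`, computed from a batch of probability ratios and advantage estimates. `state_loss_weight` and `state_prediction` are hypothetical names for the weighting hyperparameter and the on/off flag, and this is plain Python rather than any library's API:

```python
from typing import List, Sequence


def ppo_clipped_loss(ratios: Sequence[float], advantages: Sequence[float],
                     epsilon: float = 0.2) -> float:
    """Negative PPO clipped surrogate objective, averaged over the batch (to be minimized)."""
    terms = []
    for r, a in zip(ratios, advantages):
        clipped_r = min(max(r, 1.0 - epsilon), 1.0 + epsilon)
        terms.append(min(r * a, clipped_r * a))
    return -sum(terms) / len(terms)


def l1_state_loss(predicted: List[List[float]], observed: List[List[float]]) -> float:
    """Mean absolute error over all components of all predicted next states."""
    errors = [abs(p - o)
              for pred, obs in zip(predicted, observed)
              for p, o in zip(pred, obs)]
    return sum(errors) / len(errors)


def total_loss(ratios: Sequence[float], advantages: Sequence[float],
               predicted: List[List[float]], observed: List[List[float]],
               state_loss_weight: float = 1.0, state_prediction: bool = True,
               epsilon: float = 0.2) -> float:
    loss = ppo_clipped_loss(ratios, advantages, epsilon)
    if state_prediction:
        # Auxiliary next-state term, scaled by the extra weighting hyperparameter.
        loss += state_loss_weight * l1_state_loss(predicted, observed)
    return loss
```

With `state_prediction=False` the result reduces to the plain PPO loss, so the flag leaves existing behaviour untouched.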

Does that seem reasonable, or is it a bad idea? Would it be possible to implement something along these lines, with a simple flag to turn it on or off (named something like state_prediction_flag)?

AlexKuhnle commented 5 years ago

Working towards a generic world-model-style generative memory is one of the longer-term development goals, but as you say, it may take some time. Adding a simpler next-state-prediction auxiliary loss is less problematic, although proper integration requires some thinking. Initially, it's probably best to just subclass the corresponding model/agent class and add this auxiliary loss, and we can then work towards integrating the feature as a generic configuration option. Happy to hack something together and see how it goes. :-)
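A rough sketch of the subclassing approach, in plain Python. `PPOModel` here is a hypothetical stand-in for the existing model class, not Tensorforce's actual class or API, and `state_prediction_flag` / `state_loss_weight` are illustrative option names:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Batch:
    ratios: List[float]                       # pi_new(a|s) / pi_old(a|s)
    advantages: List[float]
    predicted_next_states: List[List[float]]
    observed_next_states: List[List[float]]


class PPOModel:
    """Hypothetical stand-in for an existing PPO model class (not Tensorforce's API)."""

    def __init__(self, epsilon: float = 0.2):
        self.epsilon = epsilon

    def loss(self, batch: Batch) -> float:
        # Negative clipped surrogate objective, averaged over the batch.
        terms = [min(r * a, min(max(r, 1 - self.epsilon), 1 + self.epsilon) * a)
                 for r, a in zip(batch.ratios, batch.advantages)]
        return -sum(terms) / len(terms)


class StatePredictionPPOModel(PPOModel):
    """Subclass that adds a next-state-prediction auxiliary loss behind a flag."""

    def __init__(self, epsilon: float = 0.2, state_prediction_flag: bool = True,
                 state_loss_weight: float = 1.0):
        super().__init__(epsilon)
        self.state_prediction_flag = state_prediction_flag
        self.state_loss_weight = state_loss_weight

    def loss(self, batch: Batch) -> float:
        base = super().loss(batch)  # unchanged PPO loss from the parent class
        if not self.state_prediction_flag:
            return base
        # Mean absolute (L1) error between predicted and observed next states.
        errors = [abs(p - o)
                  for pred, obs in zip(batch.predicted_next_states, batch.observed_next_states)
                  for p, o in zip(pred, obs)]
        return base + self.state_loss_weight * sum(errors) / len(errors)
```

Only overriding `loss` keeps the parent class untouched, which is what would make it straightforward to later turn the flag into a generic configuration option.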

AlexKuhnle commented 5 years ago

Added to roadmap