KonduitAI / deeplearning4j

Eclipse Deeplearning4j, ND4J, DataVec and more - deep learning & linear algebra for Java/Scala with GPUs + Spark
http://deeplearning4j.konduit.ai
Apache License 2.0

RL4J partial fix for recurrent networks partial support #531

Closed: aboulang2002 closed this pull request 4 years ago

aboulang2002 commented 4 years ago

What changes were proposed in this pull request?

There are 4 things that need to be done before fully supporting RNNs:

  1. The network's output for a given observation should be the same everywhere (in the policy and in the algorithm). This is always the case with non-recurrent networks, but not with recurrent ones, since the network's state changes with every call to rnnTimeStep().
  2. All observations must be 'seen' by the network; otherwise its state will be incorrect. (This is currently handled by a temporary fix in EpsGreedy.)
  3. Algorithms that use a shared "target" network: since the network is shared, its state may have been changed by another thread.
  4. Initial state after an update / replay-memory-based algorithms: after a network update, and for each transition drawn from the replay memory, the network's state must be set correctly (it is presently reset to zero). According to "Recurrent Experience Replay in Distributed Reinforcement Learning" (see https://openreview.net/forum?id=r1lyTjAqYX), the best approach is to keep the stale state and use a burn-in trajectory to bring the state to the best approximation we can (see the sketch after this list).
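
As a rough illustration of the burn-in idea from that paper (this is not part of this PR), here is a sketch using placeholder names (RecurrentModel, setRecurrentState, step) that do not exist in RL4J:

```java
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.List;

class BurnInReplaySketch {

    /** Minimal recurrent-model interface assumed only for this sketch. */
    interface RecurrentModel {
        void setRecurrentState(INDArray state);
        INDArray step(INDArray observation);
    }

    /**
     * Replays a stored sequence: restore the recurrent state saved when the
     * sequence was recorded (the "stale" state), run a short burn-in prefix
     * to refresh it, then compute outputs for the training portion.
     */
    static void replaySequence(RecurrentModel network,
                               INDArray staleState,
                               List<INDArray> burnInObservations,
                               List<INDArray> trainingObservations) {
        network.setRecurrentState(staleState);       // start from the stored (stale) state
        for (INDArray obs : burnInObservations) {
            network.step(obs);                       // burn-in: outputs are discarded
        }
        for (INDArray obs : trainingObservations) {
            INDArray output = network.step(obs);     // these outputs feed the training targets
            // ... build targets / loss from 'output' here ...
        }
    }
}
```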

This PR fixes point 1.

For the fix, a cache is needed in the network implementations (in org.deeplearning4j.rl4j.network). Instead of adding a cache to every network, the old implementations have been deprecated and a few new classes implementing the common behavior have been added (BaseNetwork, INetworkHandler, ComputationGraphHandler, MultiLayerNetworkHandler, CompoundNetworkHandler). Two networks built on BaseNetwork have also been added (QNetwork and ActorCriticNetwork). A simplified sketch of the caching idea is shown below.
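
This is a minimal sketch of the caching idea only, assuming a simplified single-entry cache; the real BaseNetwork/INetworkHandler classes are more involved, and the wrapper and interface names here are placeholders:

```java
import org.nd4j.linalg.api.ndarray.INDArray;

class CachingNetworkSketch {

    /** Minimal recurrent-model interface assumed only for this sketch. */
    interface RecurrentModel {
        INDArray step(INDArray observation);
    }

    private final RecurrentModel network;
    private INDArray lastObservation;
    private INDArray lastOutput;

    CachingNetworkSketch(RecurrentModel network) {
        this.network = network;
    }

    /**
     * Returns the network output for an observation. The recurrent state is
     * advanced only on the first call for a given observation; repeated calls
     * (e.g. from the policy and then from the learning algorithm) reuse the
     * cached output, so both see exactly the same result.
     */
    INDArray output(INDArray observation) {
        if (observation != lastObservation) {        // new observation: compute and cache
            lastOutput = network.step(observation);  // advances the recurrent state once
            lastObservation = observation;
        }
        return lastOutput;                           // cached result for repeated queries
    }

    /** Invalidates the cache, e.g. after the network has been fitted. */
    void invalidateCache() {
        lastObservation = null;
        lastOutput = null;
    }
}
```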

The new cache also brings a significant performance boost for all BaseNetwork implementations (see TMazeExample and AgentLearnerCartpole).

I also added a simple test environment, DoAsISayOrDont. With TMaze, the rewards are too sparse to test small update batches, so I needed an environment with frequent rewards that could not be solved by non-recurrent networks. In DoAsISayOrDont, the agent is supposed to either do as told or do the opposite, and this directive is only "seen" when it changes. The directive also switches randomly during the game. A +1 reward is given when the agent acts correctly, and -1 otherwise. A recurrent network can learn what to do, but a non-recurrent one will only learn to follow one of the two directives (i.e. for 200-step episodes, its score will bounce around 100). A rough sketch of the environment logic follows.
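
The sketch below only illustrates the rules described above, with placeholder names and an assumed switch probability; the actual DoAsISayOrDont implementation may differ:

```java
import java.util.Random;

class DoAsISayOrDontSketch {

    private final Random rnd = new Random();
    private boolean doAsISay = true;          // current directive
    private boolean directiveVisible = true;  // the directive is only observable when it changes

    /** Observation: [requested action, directive signal (0 when hidden)]. */
    double[] observe(int requestedAction) {
        double directiveSignal = directiveVisible ? (doAsISay ? 1.0 : -1.0) : 0.0;
        return new double[] { requestedAction, directiveSignal };
    }

    /** Rewards +1 if the action matches the current directive, -1 otherwise. */
    double step(int requestedAction, int chosenAction) {
        boolean correct = doAsISay
                ? chosenAction == requestedAction
                : chosenAction != requestedAction;
        double reward = correct ? 1.0 : -1.0;

        // The directive switches randomly during the game; only then is it shown again.
        if (rnd.nextDouble() < 0.1) {             // switch probability is a guess for this sketch
            doAsISay = !doAsISay;
            directiveVisible = true;
        } else {
            directiveVisible = false;
        }
        return reward;
    }
}
```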

Use the included NStepRnn to try it. It should reach a score of ~100 quickly, then bounce around that score for a while before the network figures out what to do. Expect about 550 episodes to reach the maximum score of 200 with separate actor and critic networks, and around 1600 episodes with a combined network.

How was this patch tested?

Manual and unit tests

Quick checklist

The following checklist helps ensure your PR is complete: