keras-rl / keras-rl

Deep Reinforcement Learning for Keras.
http://keras-rl.readthedocs.io/
MIT License

model.add(Flatten(input_shape=(1,) + env.observation_space.shape)) #19

Closed · opened by yongduek · closed 8 years ago

yongduek commented 8 years ago

Could you please give a short explanation of the Flatten layer used in dqn_cartpole.py?

model.add(Flatten(input_shape=(1,) + env.observation_space.shape))

Maybe it is there to generalize across different problem environments, but it is not easy to figure out.

Compared with the corresponding part of dqn_atari.py, it seems that (1,) in the code corresponds to the size of the time window. So the input shape (1,4) = (1,)+(4,) in this case is flattened into a vector of 4 elements (something like array([1,1,1,1])) by the flattening operation.
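The shape arithmetic in the question can be checked in plain Python. This is a minimal sketch: `OBS_SHAPE` stands in for `env.observation_space.shape` (CartPole's observation is a 4-element vector), and `flat_size` is a hypothetical helper, not part of keras-rl or Keras. Note that Keras' `input_shape` excludes the batch dimension, so the network actually receives tensors of shape `(batch, 1, 4)`.

```python
from math import prod

# Stand-in for env.observation_space.shape in CartPole: a 4-element vector
# (cart position, cart velocity, pole angle, pole angular velocity).
OBS_SHAPE = (4,)
WINDOW_LENGTH = 1  # dqn_cartpole.py uses a window of a single observation

# Tuple concatenation, as in the model definition:
# (1,) + (4,) -> (1, 4)
input_shape = (WINDOW_LENGTH,) + OBS_SHAPE


def flat_size(shape):
    """Number of features Flatten produces per sample (hypothetical helper)."""
    return prod(shape)


print(input_shape)             # (1, 4)
print(flat_size(input_shape))  # 4
```

The same helper gives the per-sample feature count for any window: a window of 4 observations of shape (10, 20) flattens to 4 * 10 * 20 = 800 features.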

matthiasplappert commented 8 years ago

Exactly, it is related to the window length. Let's assume that your observations have shape (10, 20). If window_length=4, the input to the network would have shape (4, 10, 20). Whether you flatten or not depends on what you want to achieve. In the Atari example, the convolution is performed before flattening since the same set of filters should be used for each frame. In other examples, it might make more sense to flatten immediately.
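To make these shapes concrete, here is a NumPy sketch of building one window and of the two choices described above. It is an illustration under stated assumptions: `stack_window`, its zero-padding at the start of an episode, and the shared per-frame matrix `shared_w` are hypothetical and do not reproduce keras-rl's internals (in keras-rl, the memory object configured with `window_length` assembles the windows). Option B uses a plain matrix multiply as a stand-in for a convolution; the point is only that one set of weights is reused for every frame before flattening.

```python
from collections import deque

import numpy as np

OBS_SHAPE = (10, 20)  # example observation shape from the reply
WINDOW_LENGTH = 4


def stack_window(history, window_length=WINDOW_LENGTH, obs_shape=OBS_SHAPE):
    """Stack the most recent observations into (window_length, *obs_shape).

    Hypothetical helper: if fewer than window_length observations exist
    (e.g. at the start of an episode), the front is padded with zeros.
    """
    recent = list(history)[-window_length:]
    padding = [np.zeros(obs_shape)] * (window_length - len(recent))
    return np.stack(padding + recent)


rng = np.random.default_rng(0)
history = deque(maxlen=WINDOW_LENGTH)  # keeps only the latest 4 observations
for _ in range(6):
    history.append(rng.normal(size=OBS_SHAPE))

window = stack_window(history)       # shape (4, 10, 20)

# Option A: flatten immediately; an MLP then sees one 800-feature vector.
flat = window.reshape(-1)            # shape (800,)

# Option B: apply the same weights to every frame first (shared across the
# window, analogous to reusing one set of conv filters), then flatten.
shared_w = rng.normal(size=(20, 8))  # hypothetical per-frame projection
per_frame = window @ shared_w        # matmul broadcasts over frames: (4, 10, 8)
flat_after = per_frame.reshape(-1)   # shape (320,)
```

With Option A the network must learn separate weights for each frame position; with Option B the per-frame weights are shared, which is the motivation given above for convolving before flattening in the Atari example.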

In the case of the cartpole example, the window length is 1 (since the observation already contains velocities etc.). Flattening the input thus simply reduces the dimension by one and otherwise keeps the observation identical.
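A NumPy sketch of what Flatten does to a CartPole batch, assuming a batch size of 32 (arbitrary) and random values in place of real observations. Keras' `Flatten` is emulated here by a row-major reshape that keeps the batch axis; no Keras code is involved.

```python
import numpy as np

# Hypothetical batch of 32 CartPole windows with window_length=1:
# shape (batch, window_length, obs_dim) = (32, 1, 4).
rng = np.random.default_rng(0)
batch = rng.normal(size=(32, 1, 4))

# Flatten keeps the batch axis and collapses all remaining axes,
# which for this input is equivalent to the reshape below.
flattened = batch.reshape(batch.shape[0], -1)  # shape (32, 4)

# With a window of one, this only drops the singleton axis:
# every observation value is kept, in the same order.
assert np.array_equal(flattened, batch[:, 0, :])
assert np.array_equal(flattened, np.squeeze(batch, axis=1))
```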

Does this clarify things?

matthiasplappert commented 8 years ago

Closing this since the question has (hopefully) been answered. If this needs further clarification, please reopen the issue.

yongduek commented 8 years ago

Hello, the question was answered; sorry for not responding earlier due to my tight work schedule. I should have left a message. Thanks a lot for the answer.

Yongduek Seo Professor +82 10 9296 8896 Department of Media Technology Sogang University, Korea

On 6 September 2016 at 22:21, Matthias Plappert notifications@github.com wrote:

Closed #19 https://github.com/matthiasplappert/keras-rl/issues/19.
