miyosuda / async_deep_reinforce

Asynchronous Methods for Deep Reinforcement Learning
Apache License 2.0

Problem while using the code #1

Open originholic opened 8 years ago

originholic commented 8 years ago

Hello @miyosuda,

Thanks for sharing the code. Please ignore the title: I tried your code on the cart-pole balance control problem instead of the Atari games, and it works well. But I have a few questions.

I am curious: in the asynchronous paper they also use another model with one linear layer, one LSTM layer, and a softmax output. I am thinking of using this model to see whether it improves the results. Can you suggest how the LSTM could be implemented in TensorFlow for playing Atari games?

I also noticed that the accumulated states and rewards are reversed. Do the actions and values need to be reversed as well? It did not make any difference when I tried it, I am just wondering why.

states.reverse()
rewards.reverse()

Last, do you really need to accumulate the gradients and then apply the update, given that TensorFlow can handle a 'batch' update directly?
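To illustrate what I mean by a 'batch' update, here is a minimal NumPy sketch (illustrative only, not code from this repo): for a loss that is a sum over the t_max timesteps, the sum of the per-step gradients equals the gradient computed on the whole batch at once.

import numpy as np

np.random.seed(0)
w = np.random.randn(4)             # parameters of a toy linear value function
states = np.random.randn(5, 4)     # t_max = 5 accumulated states
targets = np.random.randn(5)       # discounted returns R_t

def grad_single(w, s, r):
    # gradient of 0.5 * (w.s - r)^2 with respect to w
    return (w @ s - r) * s

accumulated = sum(grad_single(w, s, r) for s, r in zip(states, targets))
batched = states.T @ (states @ w - targets)   # gradient of the summed (batched) loss

print(np.allclose(accumulated, batched))      # True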

miyosuda commented 8 years ago

I'm really glad to know that you've succeeded in reproducing the continuous model.

About the reverse(): as you say, I forgot to add

actions.reverse()
values.reverse()

in addition to

states.reverse()
rewards.reverse()

and I've pushed the fix just now. Let me explain why the lists are reversed. In the pseudo code of the A3C algorithm in the DeepMind paper, there is

for i in {t-1, ..., t_start} do

This means that "i" decreases like

t-1, t-2, t-3, ..., t_start

This is why I call reverse() on the lists collected in the loop.
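To make the ordering concrete, here is a minimal sketch (illustrative names, not the repo's exact code) of the backward accumulation that the reversal enables: the n-step return R is built from the last collected step back to t_start, so states, actions, rewards and values must all be iterated in the same reversed order.

GAMMA = 0.99

def accumulate(states, actions, rewards, values, R):
    # R is 0 for a terminal state, otherwise the bootstrapped V(s_t)
    batch = []
    for s, a, r, v in reversed(list(zip(states, actions, rewards, values))):
        R = r + GAMMA * R
        advantage = R - v
        batch.append((s, a, R, advantage))
    return batch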

As far as I have tried with the Atari model, I still couldn't reproduce good learning results. I was planning to implement the LSTM after reproducing a good Atari result, but since you've succeeded with the continuous model, I should try the LSTM now. Please wait.

About the batch: let me think about whether the gradient accumulation can be replaced with a batch update or not. Just a moment please.

Thank you for the suggestions!!

originholic commented 8 years ago

Many thanks for your reply, and glad to hear that you also plan to work on the LSTM model.

I just uploaded the testing codes (based on this repo) for the "batch" update that mentioned. https://github.com/originholic/a3c_vrep.git

I only tested it with the cart-pole balance domain, but somehow it actually takes longer to reach the desired score than your implementation. I will investigate this later; for now I will continue working with your implementation to study the LSTM model, which I am not familiar with.

Also, instead of a constant learning rate:

math.exp( log_lo * ( 1-rate ) + log_hi * rate)

I don't know whether the random initialization of the learning rate mentioned in the paper can help to improve the results:

math.exp( random.uniform( log_lo, log_hi ) )
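For reference, a minimal sketch of both variants (assuming log_lo and log_hi are the natural logs of the bounds 1e-4 and 1e-2 used in the paper; the names are illustrative):

import math
import random

log_lo = math.log(1e-4)
log_hi = math.log(1e-2)

def lr_interpolated(rate):
    # deterministic point on the log scale, rate in [0, 1]
    return math.exp(log_lo * (1.0 - rate) + log_hi * rate)

def lr_log_uniform():
    # random sample, uniform in log space, i.e. LogUniform(1e-4, 1e-2)
    return math.exp(random.uniform(log_lo, log_hi))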

miyosuda commented 8 years ago

I just uploaded the testing codes (based on this repo) for the "batch" update that mentioned. https://github.com/originholic/a3c_vrep.git

Thanks! I'll try.

About the LSTM: I'm also new to LSTMs and just started studying them recently, so please don't expect too much! However, I'm really interested in the 3D labyrinth model with LSTM, so I would like to try it.

About randomizing the learning rate with log_uniform: I also used to randomize the initial learning rate for each thread with log_uniform. However, when I looked at the figure on page 23, I found that the learning rate varies from 10^-4 to 10^-2, uniformly distributed on a log scale.

So my understanding of the log_uniform function is that they use it to find the best hyper parameter when doing a grid search.

(In the graphs on pages 14 and 22, they also use a log scale for the grid-searched parameters.)

However, I'm not sure my understanding is correct.

originholic commented 8 years ago

Thanks for pointing that out.

After rethinking the random initialization, I think you are right: the initial learning rates sampled from the LogUniform range were used to demonstrate the sensitivity of their methods. And it makes sense that a constant (or best-choice) learning rate is applied for RMSProp and decayed to zero over time.

Sorry, my bad; I just got confused by this phrase in the paper:

"each using a different random initialization and initial learning rate"

miyosuda commented 8 years ago

No problem. Any suggestions and discussions are always welcome. Thanks!

originholic commented 8 years ago

Hi, tuning in again. May I ask about the continuous action domain in the asynchronous paper? They used two policy outputs: a linear layer to represent the mean, and a linear layer with a Softplus activation to represent the variance. I am wondering how the policy loss can be calculated with these two outputs?

self.policy_loss = -( tf.reduce_sum( tf.mul( tf.log(self.pi), self.a ) ) * self.td + entropy * entropy_beta )

I am thinking of calculating the loss separately, by making two of the above policy loss functions, one per output. Does this make sense to you? Sorry, this might be out of the scope of your interest, as the 3D labyrinth doesn't require continuous actions, but any suggestion is highly appreciated. Many thanks!

miyosuda commented 8 years ago

I have never tried a continuous model, so today I looked into another simple cart-pole actor-critic sample without a NN to learn about it. How to define the policy loss for continuous actions is still difficult for me, so I'll try the continuous model in a separate branch. (I'm also interested in the continuous model.) Maybe it is natural to make two different loss functions for the mean and the variance, but I'm not sure yet. I'll try to figure it out.

By the way, even with the discrete action model I'm implementing now, the policy loss function is the most difficult part for me, and I'm still not certain it is 100% correct.

However, when I tried a simple 2D grid maze model (which I implemented in the debug_maze branch), the program succeeded in finding the shortest path with this policy loss function. So the loss function for discrete actions seems fine.

Anyway, I'll report here if I get any results with the continuous model.

originholic commented 8 years ago

Thanks for the reply. As far as I can tell from your code, the policy loss function for the discrete domain is calculated using the negative log-likelihood of the softmax output.

After doing some searching, maybe I can apply the same kind of loss function, i.e. a negative log-likelihood, but with a Gaussian (normal) distribution instead of the softmax, since the outputs are a mean and a variance. So I think the loss function looks like the following, where sigma2 is the variance and mu is the mean:

D = tf.to_float(tf.size(self.a))               # action dimensionality
x_prec = tf.exp(-tf.log(self.sigma2))          # precision = 1 / sigma^2
x_diff = tf.sub(self.a, self.mu)               # (a - mu)
x_power = tf.square(x_diff) * x_prec * -0.5    # -0.5 * (a - mu)^2 / sigma^2
# negative log-likelihood of a under N(mu, sigma^2)
gaussian_nll = (tf.reduce_sum(tf.log(self.sigma2)) + D * tf.log(2 * np.pi)) / 2 - tf.reduce_sum(x_power)
self.policy_loss = gaussian_nll * self.td + entropy_beta * entropy

Sorry for the messy typing. I will try this out to see whether it works for the continuous cart-pole domain, and let you know how it goes. Thanks.
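As a quick sanity check (standalone NumPy/SciPy, not the TensorFlow code above), the formula should match the negative sum of independent Gaussian log densities:

import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
a = rng.normal(size=3)                    # sampled action
mu = rng.normal(size=3)                   # policy mean output
sigma2 = rng.uniform(0.5, 2.0, size=3)    # policy variance output (Softplus > 0)

D = a.size
nll = 0.5 * (np.sum(np.log(sigma2)) + D * np.log(2 * np.pi)) \
      + 0.5 * np.sum((a - mu) ** 2 / sigma2)

print(np.isclose(nll, -np.sum(norm.logpdf(a, loc=mu, scale=np.sqrt(sigma2)))))  # True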

miyosuda commented 8 years ago

Is this the explanation of this loss function?

http://docs.chainer.org/en/stable/reference/functions.html#chainer.functions.gaussian_nll

I really want to know the result. There is a lot for me to learn from this thread. Super thanks!!

originholic commented 8 years ago

Yes, that's right, the negative log-likelihood of the normal distribution is from the Chainer site, but I also found another version called the maximum log-likelihood; I think they are the same thing judging from the formula alone. Same here, there are lots of methods out there waiting to be learned, and it's easy to get confused.

I tried the loss function based on your code, and it works moderately well on the cart-pole balance task in the continuous action domain; at least it is able to converge (that is, reach the desired score). But I probably need some more examples to study before concluding that the loss function actually works for continuous actions. So I'll keep working on it!! Thanks.

However, when I went back and tried it with the "batch" method, it reached a score of around 2000 (the desired score was 3000) and then the network somehow diverged immediately (I am not entirely sure whether it diverged or exploded; the network just gave "NaN" for its output all the time).

miyosuda commented 8 years ago

Thank you for reporting. I was trying the batch approach with my discrete action code in the "batch" and "debug_maze_batch" branches. I'm checking whether the gradient accumulation works correctly when batched.

ghost commented 8 years ago

@miyosuda: Hey, I had been trying to implement the same thing in Theano. I implemented an A2C version (single thread), which never converges in spite of training on a GPU for even a week or so... Then I came across your git source. Could you please let me know what exactly the issues are that you are facing right now that keep your learning from being as good as required? Is it NaNs and stability, or no convergence of the network? We can try to catch up on this, as I am also in urgent need of an actor-critic learner on Pong.

miyosuda commented 8 years ago

@aravindsrinivas Thank you for joining the discussion. Let me explain what I tried, what I succeeded at, and what I have not succeeded at yet.

I have been trying Pong with A3C with 8 CPU threads. The problem is that the game score does not increase even after one or two days of learning. The AI can hit the ball back three or four times in one game, but the score does not increase like the DeepMind paper shows.

(As far as I have tried with Pong, the network does not diverge to NaN.)

To confirm whether my implementation has a problem or not, I tried an easier task. I implemented a 10x10 2D grid maze and let the A3C algorithm find the shortest path. After running for two or three minutes, the AI converged to the optimal result. (It succeeded in finding the shortest path.)

I tried this in "debug_maze" branch.

After confirming that the algorithm can solve an easy RL task, I've been changing the hyper parameters little by little to check whether the game score will increase like the paper shows. But the result is still the same.

I once heard that DQN is very sensitive to hyper parameters, and as far as I can tell from the paper, the hyper parameters of this method also seem sensitive.

Along with tuning hyper parameters, I'm also planning to try another task, one that doesn't use a CNN.

By the way, the key idea of this method is to get network stability by running multiple threads at the same time, so that it does not diverge or oscillate. So if you have problems with a single thread, how about trying multiple threads?

I have never tried Theano, but if you would like to run it with TensorFlow, I can help you.

ghost commented 8 years ago

@miyosuda I mailed the authors (from DeepMind). These are some hyper parameters that they explicitly told me in the mail:

The decay parameter (called alpha in the paper) for RMSProp was 0.99 and the regularization constant (called epsilon in the paper) was 0.1. The maximum allowed gradient norm was 40. The best learning rates were around 7*10^-4. Backups of length 20 were used which corresponds to setting the t_max parameter to 20.

Also, I am not sure if you used frame skip in your implementation. From what I saw in gamestate.py, you just have reward = ale.act(action)? Shouldn't it be in a loop, like for _ in range(frame_skip): reward += ale.act(action)?

Also, are you clipping the reward to lie between -1 and 1? In DQN, rewards were clipped between -1 and 1. I am not sure what the rewards are for Pong from the ALE source.
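For concreteness, a sketch of the loop I mean, combined with DQN-style reward clipping (this assumes the ALE Python interface; the constant and function names are just illustrative, not from this repo):

FRAME_SKIP = 4

def step(ale, action):
    reward = 0
    for _ in range(FRAME_SKIP):
        reward += ale.act(action)       # repeat the chosen action for 4 frames
        if ale.game_over():
            break
    return max(-1, min(1, reward))      # clip the reward to [-1, 1] as in DQN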

miyosuda commented 8 years ago

Wowwowow! These are the parameters that I really wanted!!! Super thanks!!

I was always using t_max of 5 and didn't use gradient norm clipping. (In the paper there was only one line referring to gradient norm clipping, so I didn't try it.)

also, I am not sure if you used the frame skip in your implementation

I've set frame skipping to every 4 frames in the "ale.cfg" file, but as you say, it might be better to put a loop as you suggested.

I'm not clipping the score, but ALE Pong gives a reward of 1 or -1, so it should be OK.

Anyway, super thanks for giving me such valuable information!!! I'll try these parameters.

ghost commented 8 years ago

@miyosuda Another question: How exactly are you synchronizing the RMSProp parameters?

miyosuda commented 8 years ago

I'm accumulating gradients for t_max steps in each thread, and after that I apply the accumulated gradients with shared RMSProp. When applying the accumulated gradients, the "rms" parameter is shared among threads. (The "rms" parameter in TensorFlow corresponds to "g" in the paper.) The "momentum" parameter in RMSProp could also be shared, but I'm not using momentum in RMSProp because there was no mention of momentum in RMSProp in the paper. (I'm using 0.0 as the momentum constant in RMSProp.)

When applying the accumulated gradients with shared RMSProp, I'm not using any synchronization such as mutual exclusion among threads.

(Is this what you are asking?)

As far as I can tell from the TensorFlow source code, it seems OK to apply gradients without a lock when running on the CPU. (To run it on the GPU, I need to research more to check whether shared RMSProp can be implemented there, because memory handling on the GPU might be different from the CPU.)
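For reference, the shared RMSProp update from the paper (equation (9)) looks roughly like the following sketch. This is illustrative NumPy, not the repo's RMSPropApplier; g and theta are assumed to be arrays shared by all threads, and the constants come from the values reported above.

import numpy as np

ALPHA = 0.99      # RMSProp decay
EPSILON = 0.1     # regularization constant reported above for A3C
LR = 7e-4

def shared_rmsprop_apply(theta, g, d_theta, lr=LR):
    # g and theta are shared across threads; each thread applies its
    # accumulated gradient d_theta in place, without locking.
    g *= ALPHA
    g += (1.0 - ALPHA) * d_theta ** 2
    theta -= lr * d_theta / np.sqrt(g + EPSILON)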

ghost commented 8 years ago

Shouldn't we lock the other threads out of updating the parameters of the global network while one particular thread is already updating it with its accumulated gradient from t_max steps?

My question was related to the RMSProp gradient statistics. We keep a moving average of the RMS of the gradients, right? And that RMS is used to determine our parameter update. My question is: would the gradient values of different threads all be used together to update the moving average of the RMS? Or do we have separate moving averages for each thread, each used when that thread updates the parameters with its accumulated gradient?

In the paper, they consider both approaches, but say that having separate RMSProp statistics (mainly the moving average) is less robust than sharing the moving average. But they don't reveal how exactly they synchronize the moving average across threads.

Could you explain what you are doing?

miyosuda commented 8 years ago

@aravindsrinivas Sorry, my mistake: while checking my code, I found that the moving average of RMSProp is not shared. So my current implementation is not shared RMSProp.

I've created the RMSPropApplier class in rmsprop_applier.py. In this class, the slot named "rms" corresponds to the parameter "g" in the paper.

(The "rms" slot parameter will be passed to native code around here) https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/training_ops.cc#L143

I created this class to share the "rms" parameter among threads, but I found that an RMSPropApplier instance is created in each thread in a3c_training_thread.py.

So the moving average is calculated separately in each thread. I need to fix this. Sorry about that.

ghost commented 8 years ago

@miyosuda

Hi, I also confirmed the following:

1) The critic learning rate must be half the actor's.
2) The LR must be linearly annealed to 0 over the course of training.
3) The parameters 'g' and 'theta' (the moving average of the RMS of the gradients and, of course, the parameters) are shared across the threads (unlike your earlier version with separate RMS moving averages). Also, there is no need for locking while updating.
4) t_max = 20 means 20 perceived frames (80 with frame skip, as per the game), not 20 states, i.e. not 20 84x84x4 tensors but rather 20 84x84 frames.

A question: could you tell me at what speed (steps per second, where a step refers to a decision taken by the network during gameplay) the code runs for the 8-thread version? DeepMind say they get 1000 steps/sec from 16 threads, so a single thread should be around 70. But I was never able to run at 70 with my single-thread code; it used to run at 30.

miyosuda commented 8 years ago

@aravindsrinivas

Thank you for providing such valuable information again.

1) the critic learning rate must be half the actor's..

I got it. I'll set the LR for the actor starting from 7*10^-4, and 3.5*10^-4 for the critic.

2) the LR must be linearly annealed to 0 over the course of training.

I got it. I've already implemented LR annealing.
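For reference, a minimal sketch of what I mean by linear annealing (the names are illustrative, and the total step count is an assumption, not the repo's constant):

INITIAL_LR = 7e-4
MAX_GLOBAL_T = 100 * 10**6   # assumed total number of global steps

def annealed_lr(global_t):
    # decrease linearly from INITIAL_LR to 0 as the shared counter global_t grows
    lr = INITIAL_LR * (MAX_GLOBAL_T - global_t) / MAX_GLOBAL_T
    return max(lr, 0.0)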

3) the parameters 'g' and 'theta' are shared across the threads.

I see. I'm now testing sharing 'g' in the "shared_rmsprop" branch. I'll merge it into the master branch later after confirming it works. In my implementation, 'theta' corresponds to the variables in the global_network instance.

4) t_max = 20 means 20 perceived frames

I wanted to ask about this too. I used to implement frame skipping via the "ale.cfg" file with the "frame_skip=4" option. With this option, every time we call ale.act(chosen_action), the emulator advances 4 frames.

So I was storing frames for each state during one backup sequence (a sequence of 5 states) like this:

(pattern A)
state[0] = { 0  4  8 12}     <- frames 0, 4, 8, 12
state[1] = { 4  8 12 16}
state[2] = { 8 12 16 20}
state[3] = {12 16 20 24}
state[4] = {16 20 24 28}

With this pattern, adjacent states share three perceived frames.

Another way to store frames with 4-frame skipping is

(pattern B)
state[0] = { 0  4  8 12}
state[1] = {16 20 24 28}
state[2] = {32 36 40 44}
state[3] = {48 52 56 60}
state[4] = {64 68 72 76}

If we choose pattern B, one chosen action continues for 16 frames. How should we implement frame skipping with t_max=20? If you have any idea about this, please teach me.

About the running speed in steps per second, I'll check it in my environment; please wait a bit!

miyosuda commented 8 years ago

@aravindsrinivas I've checked the running speed. I'm out right now, so I checked it on my MacBook Pro (Intel Core i7 2.5GHz).

It was 106 steps per second with 8 threads, so about 13 steps per second per thread. I have another Core i7-6700 desktop machine, and I remember it was 1.5x (or 2x?) faster than my MacBook Pro. (I'll check on the Core i7-6700 machine later.)

Anyway, the speed in my environment is much slower than DeepMind's.

ghost commented 8 years ago

@miyosuda

That's quite slow, I guess... Maybe I got 30 steps per second for a single thread because of the GPU. I can't understand how DeepMind got it working at 70 steps/sec for a single thread; that's almost as fast as running DQN on a GPU. So your code is about 5 times slower than DeepMind's, I guess... But we can still reproduce results with 1-2 days of running.

ghost commented 8 years ago

@miyosuda When I implemented it, I did it the same way as pattern A: (0,4,8,12), (4,8,12,16), (8,12,16,20), .... Even in DQN, that's the way they do it.

What we should do is: say we are at frame 0, we take an action and repeat it 4 times. We execute 0->1, 1->2, 2->3 and 3->4 using the same action that was decided at 0. We then decide on an action at frame 4, execute 4->5, 5->6, 6->7, 7->8 (4 repetitions), decide on an action at frame 8, and so on.

Our states would be (0,4,8,12); (4,8,12,16); (8,12,16,20), .... Since they say t_max is equivalent to 20 perceived frames, we must stop at (64,68,72,76), i.e. you stop once you decide on an action at the 76th frame, and repeat it 4 times to get to the 80th frame. The 80th frame (with the past 3 perceived frames 68, 72, 76) would be our s_{t_max}, which is used to calculate our target through V(s_{t_max}). We would have 17 tuples (0,4,8,12), (4,8,12,16), ..., (64,68,72,76) for s_t for t = 0 to t_max - 1. s_{t_max} would be (68,72,76,80).
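In code, this scheme (pattern A with an action repeat of 4) could look roughly like the sketch below; get_screen and choose_action are placeholders, not functions from this repo.

from collections import deque

FRAME_SKIP = 4

def perceive(ale, history, get_screen, choose_action):
    # history: deque(maxlen=4) holding the last 4 perceived 84x84 frames,
    # e.g. frames (0, 4, 8, 12)
    state = list(history)
    action = choose_action(state)        # decide once per perceived frame
    reward = 0
    for _ in range(FRAME_SKIP):          # repeat the chosen action 4 times
        reward += ale.act(action)
    history.append(get_screen(ale))      # next state shares 3 frames with this one
    return state, action, reward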

I will actually try to implement a Theano version of this now that so many details are clear. Please keep us updated on whether you are able to implement it.

joabim commented 8 years ago

Hi, I also found this repo after trying to implement A3C from the DeepMind article; it's nice to see progress! However, when running the implementation, the agents seem to perform only three actions from the legal action set given by the ALE interface, and these actions correspond to idle, fire and right. Could this be a result of the provided Pong binary being problematic, or of ACTION_SIZE being set to 3? The reason I'm asking is that when displaying the results after a few hours of training, the paddle is stuck at the edge of the Pong playing field.

ghost commented 8 years ago

@joabim I think it is because ACTION_SIZE is set to 3. He is using only the legal actions allowed for the Pong game, and Pong has only 3 actions (moving up/down/staying idle).

joabim commented 8 years ago

@aravindsrinivas You're right! But for some reason, instead of up/down/idle, my runtime printouts seem to suggest that the agents perform the actions noop/idle (0), fire (1) and right (3) (which corresponds to up when testing pong.bin in the Stella emulator), according to the Arcade Learning Environment documentation. But maybe I'm misinterpreting the minimal action set. Do you get a moving paddle?

miyosuda commented 8 years ago

@aravindsrinivas Now I understand what you mean. I'll try it that way too. Thanks!

@joabim Thank you for joining the discussion. It seems strange to get [0, 1, 3] from the Pong game ROM.

I tried this code,

from ale_python_interface import ALEInterface
ale = ALEInterface()
ale.loadROM("pong.bin")
real_actions = ale.getMinimalActionSet()
print "minimal actions=", real_actions

and I got the result

minimal actions= [0 3 4]

[0, 3, 4] means [idle, right, left]. Could you try the code above?

miyosuda commented 8 years ago

@joabim Ah, there is another function named getLegalActionSet() in ALE, and I also tried it:

from ale_python_interface import ALEInterface
ale = ALEInterface()
ale.loadROM("pong.bin")
legal_actions = ale.getLegalActionSet()
print "legal actions=", legal_actions

and the result was

 legal actions= [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17]

So I think getLegalActionSet() just returns all default actions. Which function are you using, getMinimalActionSet() or getLegalActionSet()?

joabim commented 8 years ago

@miyosuda Exactly, when I invoke getMinimalActionSet() I get

[ 0 1 3 4 11 12]

so I'm wondering what has happened. I have tried rebuilding ALE but it doesn't change. When just running the code, the score remains at -21. By setting (forcing)

self.real_actions = [0, 3, 4]

I actually get some results from the training, as expected. It's really weird that I can't get the real action set from ALE!

miyosuda commented 8 years ago

@joabim Are you using the same ROM as I am? And the file name of the ROM should be "pong.bin" (because ALE seems to detect the game ROM type from the file name).

joabim commented 8 years ago

@miyosuda Indeed I am! However, I'm using Python 3.5 and Anaconda; there could be some problem with loadROM(...) and the bytes literal as input (I had to prefix "pong.bin" with b to get it working). The ALEInterface displays the correct information in the terminal, though... I could set up a Python 2.7 environment and try again.

zhuchiheng commented 8 years ago

@originholic Hi, I adjusted the epsilon in your implementation to 0.1 and it converged at last (T=1901383 in my test). Thanks to the hyper parameters from @aravindsrinivas.

Also, I don't know why you @originholic would use a lock for the "env" in the training thread. Each thread has its own "env" object, so from what I understand they don't need the lock...

ghost commented 8 years ago

I think the entropy beta must also be set to 0.01. It is 0.1 in the constants.py file. In the paper, it is mentioned as 0.01.

originholic commented 8 years ago

@zhuchiheng Thanks for trying out the code. Yes, the lock for the env is not actually required; it is there because I copied it directly from another project of mine and didn't manage to clean up the code. I think it can still reach the desired scores as long as the epsilon is larger than 0.001 if you run the cart-pole environment. Also, fixing the initial learning rate instead of using random initialization should probably reach the desired score faster.

@miyosuda, @aravindsrinivas Wow, I have a lot to catch up on since I've been away for a while, and many thanks @aravindsrinivas for the hyper parameters and helpful suggestions. Regarding the speed, I agree that because of Python's GIL, the threading module cannot utilize the CPU efficiently for multi-threading; using the multiprocessing library instead of threading could probably help to speed things up.

miyosuda commented 8 years ago

@aravindsrinivas I see. As you say, the entropy regularization constant is given as 0.01 on page 11.

By the way, you wrote in a previous post that

The decay parameter (called alpha in the paper) for RMSProp was 0.99 and the regularization constant (called epsilon in the paper) was 0.1.

What does "epsilon" in this comment mean? Epsilon term in RMSProp calculation in equation (9) in page 9? (I thought that epsilon in RMSProp is a small constant like 1e-10 to avoid zero division, but if your comment means this term, I'll try it!)

And did you hear anything from the DeepMind authors about the discount factor gamma?

miyosuda commented 8 years ago

@originholic Thank you for introducing the multiprocessing library in Python. I didn't know about it. Let me check it out.

miyosuda commented 8 years ago

@joabim I found out why your minimal action set size differs from mine.

ALE seems to have changed the minimal action set for Pong one month ago.

https://github.com/mgbellemare/Arcade-Learning-Environment/commit/e1c811a50848fa0165d041bbdbdc0ae7bc116f31

I'm now using a forked version of ALE to which I added some modifications in order to run in a multithreaded environment.

https://github.com/miyosuda/Arcade-Learning-Environment

This version still uses 3 actions. I don't know why they changed the action size from 3 to 6.

joabim commented 8 years ago

@miyosuda Thank you so much! That explains it!

@aravindsrinivas Awesome work on the hyper parameters! I wonder if the source code for this project will be released at some point, like they did with DQN for the "Human-level control through deep reinforcement learning" article.

ghost commented 8 years ago

@zhuchiheng Hi, I have been running this for a million (T = 109200)... But it still scores no more than 18 (mostly 20 and 21)... Can you tell me what the trajectory of the scores over T looked like for you?

ghost commented 8 years ago

@miyosuda
Does the code work for you now?

miyosuda commented 8 years ago

@aravindsrinivas The simple grid maze task converged to the optimal result easily with this code, but the Pong result is the same as yours (no more than 17, and mostly 20 and 21, after one day of learning).

ghost commented 8 years ago

Did you mean -17 and -20/21? Also, what's the value of T you get after a day? 1.2 million is too slow :-(


joabim commented 8 years ago

With the correct ALE, I also wind up at -17 at most after a day's worth of training on Pong. In Breakout, the score maxes out at 2-3 per episode.

miyosuda commented 8 years ago

@aravindsrinivas

Did you mean -17 and -20/21?

Yes, sorry, I meant -17 and -20/21.

Also, what's the value of T you get after a day?

After 19 hours 20 min, the global T was 10870045 (10.8 million) on my Core i7-6700 machine. (So about 19.5 steps per sec per thread.)

By the way, I found that ALE 0.5 has a default setting of "repeat_action_probability=0.25". This parameter was introduced in ALE 0.5 and it causes poor performance.

https://groups.google.com/forum/#!topic/deep-q-learning/p4FAIaabwlo

So I disabled this with repeat_action_probability=0.0 in "ale.cfg".
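If editing "ale.cfg" is inconvenient, the same option can (if I remember the Python interface correctly) also be disabled programmatically before loading the ROM:

from ale_python_interface import ALEInterface

ale = ALEInterface()
ale.setFloat('repeat_action_probability', 0.0)   # disable sticky actions (ALE >= 0.5)
ale.loadROM("pong.bin")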

ghost commented 8 years ago

@miyosuda

Isn't that equivalent to 43.2 million frames? That should be between 10 and 11 epochs... Their graph shows that by 10 epochs they are able to reach scores around +10. So this is definitely not working...

miyosuda commented 8 years ago

@aravindsrinivas

Isn't that equivalent to 43.2 mill frames?

Yes, it is 43.2 million frames, including skipped ones.

By the way, I'm still not sure I understand what the "epsilon" parameter you mentioned earlier means.

The decay parameter (called alpha in the paper) for RMSProp was 0.99 and the regularization constant (called epsilon in the paper) was 0.1.

Is this the epsilon of RMSProp in equation (9), or the "beta" parameter in equation (7)? Is it possible to share the original message that the DeepMind author sent you?

I'll try an easier task than Pong to confirm the correctness of my implementation.

zhuchiheng commented 8 years ago

Hi @aravindsrinivas, try these hyper parameters:

T = 0                        # Global shared counter
TMAX = 5000000               # Max iteration of global shared counter
THREADS = 8                  # Number of running threads
N_STEP = 5                   # Number of steps before update
WISHED_SCORE = 3000          # Stopper of iterative learning
GAMMA = 0.99                 # Decay rate of past observations (deep Q-learning)
ACTIONS = 1                  # Number of valid actions
STATES = 4                   # Number of states
ENTROPY_BETA = 0.001         # Entropy regularization term beta, default: 0.001

INIT_LEARNING_RATE = 0.0001  # default: 1e-3

OPT_DECAY = 0.99             # Discounting factor for the gradient, default: 0.99
OPT_MOMENTUM = 0.0           # A scalar tensor, default: 0.0
OPT_EPSILON = 0.1            # 0.005 # value to avoid zero denominator, default: 0.01

ghost commented 8 years ago

@zhuchiheng Is this for a continuous-control problem? I'm asking because you have only 1 action. I was actually talking about Pong.

miyosuda commented 8 years ago

@aravindsrinivas There is another project trying A3C, and its results seem much better than mine: https://github.com/muupan/async-rl

Please try his implementation and settings.