rmst / ddpg

TensorFlow implementation of the DDPG algorithm from the paper Continuous Control with Deep Reinforcement Learning (ICLR 2016)
MIT License

Why should terminated states be removed from the minibatch? #4

Closed: leix28 closed this issue 8 years ago

leix28 commented 8 years ago

In replay_memory.py:

    indices = np.zeros(size, dtype=np.int64)
    for k in range(size):
      # find a random index
      invalid = True
      while invalid:
        # sample an index, ignoring wrap-around at the buffer boundary
        i = random.randint(0, self.n - 2)
        # if the i-th sample is the current write position or is terminal: resample
        if i != self.i and not self.terminals[i]:
          invalid = False

      indices[k] = i

This part excludes some candidate indices from sampling.

Would indices = np.random.randint(0, self.n, size) be more appropriate?

rmst commented 8 years ago

Hey, it only removes the transitions between two episodes; transitions at the end of an episode are left untouched. If self.terminals[i] is True, then observations[i+1] is the first state of the next episode, so the pair (observations[i], observations[i+1]) would not be a real transition. Notice that in the end the terminal flags are obtained via self.terminals[indices+1], so transitions that lead into a terminal state are still sampled.
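To make the boundary case concrete, here is a minimal sketch with a hypothetical two-episode buffer (toy data, not from the repo):

import numpy as np

# Hypothetical buffer contents: two episodes of length 3.
terminals = np.array([False, False, True, False, False, True])

# A sampled index i yields the pair (observations[i], observations[i+1]).
# At i == 2 the state is terminal, so observations[3] already belongs to
# the next episode and the pair would span an episode boundary; these are
# exactly the indices the sampling loop rejects.
invalid = [i for i in range(len(terminals) - 1) if terminals[i]]
assert invalid == [2]

# i == 1 ends an episode but is still sampled; its terminal flag is read
# from terminals[i + 1], matching self.terminals[indices + 1] in the method.
assert bool(terminals[1 + 1]) is True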

Here is the full method for reference:

def minibatch(self, size):
  # sample uniformly random indices, rejecting invalid ones
  indices = np.zeros(size, dtype=np.int64)
  for k in range(size):
    # find a random index
    invalid = True
    while invalid:
      # sample an index, ignoring wrap-around at the buffer boundary
      i = random.randint(0, self.n - 2)
      # if the i-th sample is the current write position or is terminal: resample
      if i != self.i and not self.terminals[i]:
        invalid = False

    indices[k] = i

  o = self.observations[indices, ...]
  a = self.actions[indices]
  r = self.rewards[indices]
  o2 = self.observations[indices + 1, ...]
  t2 = self.terminals[indices + 1]
  info = self.info[indices, ...]

  return o, a, r, o2, t2, info
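As a side note, leix28's vectorized idea can keep the same exclusion rule via rejection sampling. Below is a minimal sketch, not code from this repository; sample_indices and its parameters are hypothetical names, and terminals is assumed to be a NumPy bool array:

import numpy as np

def sample_indices(terminals, n, current, size, rng=None):
  # Same exclusion rule as the loop above, but vectorized: draw a batch of
  # candidate indices, keep those that are neither the current write
  # position nor terminal, and redraw until `size` valid ones are collected.
  rng = rng or np.random.default_rng()
  indices = np.empty(size, dtype=np.int64)
  filled = 0
  while filled < size:
    cand = rng.integers(0, n - 1, size=size - filled)  # draws from [0, n-2]
    good = cand[(cand != current) & ~terminals[cand]]
    indices[filled:filled + good.size] = good
    filled += good.size
  return indices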
leix28 commented 8 years ago

I misunderstood some of the code. Thank you very much!