Kaixhin / Rainbow

Rainbow: Combining Improvements in Deep Reinforcement Learning
MIT License

Improve the replay memory to avoid storing the autograd graphs in it #33

Closed: deepbrain closed this issue 5 years ago

deepbrain commented 5 years ago

I was profiling this code and found that the line mem.update_priorities(idxs, loss.detach()) was causing large memory leaks and slowdowns with the current version of PyTorch. Once I switched the replay memory to use NumPy arrays instead of PyTorch tensors, memory usage dropped 3x and overall training ran 2-3x faster than with the original version. So it looks like the current code attaches something big, possibly the entire graph (despite the detach() call), to every node in the replay memory tree, and this simple change fixes it.

-Art
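
For readers following the fix described in this comment, here is a minimal sketch of a replay-memory sum tree whose node values live in a single NumPy float64 array, so updating priorities copies plain numbers in rather than keeping references to tensors. The class NumpySumTree and its methods are hypothetical and are not the repository's own tree; the sketch assumes the common array layout with capacity - 1 internal nodes followed by capacity leaves.

```python
import numpy as np


class NumpySumTree:
    """Hypothetical sum tree whose node values live in one float64 NumPy array.

    Leaves hold per-transition priorities; each internal node holds the sum of
    its two children, so the root (index 0) is the total priority mass.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # 2 * capacity - 1 nodes: capacity - 1 internal nodes followed by capacity leaves
        self.tree = np.zeros(2 * capacity - 1, dtype=np.float64)

    def _propagate(self, index):
        # Walk from a changed node up to the root, recomputing each parent's sum
        while index > 0:
            parent = (index - 1) // 2
            self.tree[parent] = self.tree[2 * parent + 1] + self.tree[2 * parent + 2]
            index = parent

    def update(self, tree_indices, priorities):
        # Values are copied into the float64 array; no reference to the caller's objects is kept
        for index, priority in zip(tree_indices, np.asarray(priorities, dtype=np.float64)):
            self.tree[index] = priority
            self._propagate(index)

    def leaf_index(self, data_index):
        # Map a position in the replay buffer to its leaf in the tree array
        return data_index + self.capacity - 1

    def total(self):
        return self.tree[0]

    def find(self, value):
        # Descend from the root: go left if value fits in the left subtree, else subtract and go right
        index = 0
        while 2 * index + 1 < len(self.tree):
            left = 2 * index + 1
            if value <= self.tree[left]:
                index = left
            else:
                value -= self.tree[left]
                index = left + 1
        return index
```

Under this design the caller is expected to pass priorities that are already NumPy values, so the tree never holds a tensor at all.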

Kaixhin commented 5 years ago

Crazy - thanks for tracking this down and submitting a fix! I'm pretty sure .detach() is supposed to cut the tensor off from the graph, but there must be some sort of tracking going on under the hood, so I think your fix is what's needed.

As for the sample method, could you check whether the torch.no_grad() context manager has the same effect? Like so:

  def sample(self, batch_size):
    with torch.no_grad():
      p_total = self.transitions.total()
      ...
      weights = weights / weights.max()
    return tree_idxs, states, actions, returns, next_states, nonterminals, weights
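
If the priorities and the tree total are kept in NumPy, as suggested in the first comment, the probability and importance-weight arithmetic inside sample() can run entirely in NumPy, which sidesteps both autograd tracking and device placement. This sketch assumes stratified proportional sampling (one uniform draw inside each of batch_size equal-mass segments) over a flat priority array, with np.searchsorted on the cumulative sum standing in for the tree lookup, and importance weights of the form (N * P(i))^(-beta) normalised by their maximum. The function name sample_proportional is hypothetical.

```python
import numpy as np


def sample_proportional(priorities, batch_size, beta, rng):
    """Hypothetical helper: draw indices with P(i) = p_i / sum_j p_j (stratified)
    and return max-normalised importance-sampling weights, using NumPy only."""
    priorities = np.asarray(priorities, dtype=np.float64)
    p_total = priorities.sum()
    # Stratified sampling: one uniform draw inside each of batch_size equal segments
    segment = p_total / batch_size
    targets = (np.arange(batch_size) + rng.uniform(size=batch_size)) * segment
    # Map each target mass onto the transition whose cumulative priority covers it
    idxs = np.searchsorted(np.cumsum(priorities), targets, side="right")
    idxs = np.minimum(idxs, len(priorities) - 1)  # guard against float round-off at the top end
    probs = priorities[idxs] / p_total  # normalised probabilities; plain floats, no device involved
    weights = (len(priorities) * probs) ** -beta  # importance-sampling weights
    weights = weights / weights.max()  # normalise so the largest weight is 1
    return idxs, weights
```

Transitions with zero priority are never selected, because side="right" skips over runs of equal cumulative sums.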

deepbrain commented 5 years ago

I tested torch.no_grad() in the sample() method. A few findings:

probs = torch.tensor(probs, dtype=torch.float32, device=self.device)/p_total

raises an error when run on a CUDA device: RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #2 'other'

so I replaced it with:

probs = torch.tensor(probs, dtype=torch.float32, device=self.device)/torch.tensor(p_total, dtype=torch.float32, device=self.device) # Calculate normalised probabilities
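
A possible alternative to wrapping p_total in its own device tensor, sketched under the assumption that p_total is a NumPy scalar once the tree is NumPy-backed (the thread does not state its type): divide in NumPy first and move the finished array to the device in one step, or convert the scalar to a Python float before dividing on the device. The probability values are made up for illustration, the snippet falls back to the CPU when CUDA is unavailable, and it was written against current PyTorch, so behaviour on PyTorch 0.4 is not verified here.

```python
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

p_total = np.float64(10.0)  # assumed: a NumPy scalar returned by a NumPy-backed sum tree
probs = np.array([1.0, 2.0, 3.0, 4.0])  # illustrative priorities of the sampled transitions

# Option 1: normalise in NumPy, then move the finished array to the device once
probs_t = torch.as_tensor(probs / p_total, dtype=torch.float32, device=device)

# Option 2: turn the scalar into a plain Python float before dividing on the device
probs_t2 = torch.tensor(probs, dtype=torch.float32, device=device) / float(p_total)
```

Both options keep every operand on the same device (or as a Python number), which is the condition the RuntimeError above is complaining about.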

According to the PyTorch 0.4.0 migration guide:

https://pytorch.org/blog/pytorch-0_4_0-migration-guide/

the detach() method returns a tensor with requires_grad=False that still shares its data with the original. My guess is that if we store this tensor in the replay buffer, PyTorch keeps all of the tensors in the loss computation graph alive in case they are needed for a backward pass later. That never happens in our case, but PyTorch has no way of knowing it.
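
The distinction being discussed can be checked directly. This small CPU-only snippet, built around an arbitrary toy loss, shows that detach() removes the link to the graph (no grad_fn, requires_grad=False) while still sharing storage with the original tensor, that on the CPU .numpy() shares that same storage, and that .tolist() copies the values into plain Python floats. It illustrates PyTorch semantics only and does not by itself confirm what was leaking in this repository.

```python
import torch

x = torch.randn(4, requires_grad=True)
loss = (x * 2) ** 2  # toy tensor that is part of an autograd graph

detached = loss.detach()
# detach() drops the graph link: no grad_fn and requires_grad is False ...
print(detached.grad_fn, detached.requires_grad)
# ... but it points at the same memory as loss, so holding it keeps that storage alive
print(detached.data_ptr() == loss.data_ptr())

# On the CPU, .numpy() also shares memory with the tensor
shared = detached.numpy()
# .tolist() copies the values into independent Python floats
copied = detached.tolist()
```

If the goal is for the replay memory to hold no reference to the loss tensor at all, a copy such as .tolist() (or np.array(..., copy=True)) is the safer thing to store.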

Kaixhin commented 5 years ago

Thanks a lot for testing this. It seems that in general I should use NumPy arrays in the experience replay memory to avoid any PyTorch tracking overhead, and convert them to tensors as late as possible. I'll merge this now, but I'll try to go through the memory this week and move more of it into NumPy.
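
A minimal sketch of that direction, assuming a circular buffer with one preallocated NumPy array per field: values are copied in on append, and get() returns NumPy batches, leaving tensor conversion to the caller. The class NumpyTransitions and its field layout are hypothetical, not the repository's memory.

```python
import numpy as np


class NumpyTransitions:
    """Hypothetical circular buffer that keeps every field in a preallocated NumPy array."""

    def __init__(self, capacity, state_shape):
        self.capacity = capacity
        self.index = 0  # next slot to write
        self.full = False  # becomes True once the buffer has wrapped around
        self.states = np.zeros((capacity, *state_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int64)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.nonterminals = np.zeros(capacity, dtype=np.bool_)

    def append(self, state, action, reward, terminal):
        # Values are copied into the arrays, so no reference to the caller's objects is kept
        self.states[self.index] = state
        self.actions[self.index] = action
        self.rewards[self.index] = reward
        self.nonterminals[self.index] = not terminal
        self.index = (self.index + 1) % self.capacity
        self.full = self.full or self.index == 0

    def __len__(self):
        return self.capacity if self.full else self.index

    def get(self, idxs):
        # Fancy indexing returns NumPy copies; converting to tensors is left to the caller
        idxs = np.asarray(idxs)
        return self.states[idxs], self.actions[idxs], self.rewards[idxs], self.nonterminals[idxs]
```

Under this layout the conversion would happen once per batch at the point of use, for example with torch.from_numpy(states).to(device), which matches the "as late as possible" idea.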