Crazy - thanks for tracking this down and submitting a fix! I'm pretty sure that .detach() is supposed to cut off that part of the graph, but it seems there must be some sort of tracking going on under the hood, so I think your fix is what's needed.
As for the sample method, would you be able to see if the no_grad() context manager has the same effect? Like so:
def sample(self, batch_size):
  with torch.no_grad():
    p_total = self.transitions.total()
    ...
    weights = weights / weights.max()
    return tree_idxs, states, actions, returns, next_states, nonterminals, weights
I tested with torch.no_grad() in the sample() method - a few findings:
- it still requires substantially more memory (2x or more, seemingly 15-20 gigabytes extra) on my system than the new numpy-based code
- it runs about 1.5x slower than with the pure numpy structs in the replay buffers
- this line:
  probs = torch.tensor(probs, dtype=torch.float32, device=self.device)/p_total
  generates an error when I run on a CUDA device (see the sketch below): RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #2 'other'
  so I replaced it with:
  probs = torch.tensor(probs, dtype=torch.float32, device=self.device)/torch.tensor(p_total, dtype=torch.float32, device=self.device)  # Calculate normalised probabilities
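For anyone hitting the same RuntimeError: it comes from mixing a CUDA tensor with a CPU tensor in a single operation, which is what happens when p_total comes out of a torch-based sum tree on the CPU. Below is a minimal, hedged sketch of the situation and two ways around it; the values and the 0-dimensional p_total are made up for illustration and are not taken from the repository.

import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

probs_np = np.array([0.2, 0.5, 0.3], dtype=np.float32)  # hypothetical sampled priorities
p_total = torch.tensor(1.0)  # CPU tensor, standing in for a torch-based tree total

probs = torch.tensor(probs_np, dtype=torch.float32, device=device)

# Dividing `probs` (possibly on CUDA) by the CPU tensor `p_total` is what raised
# the RuntimeError above on that PyTorch version; exact behaviour depends on the
# version and on whether p_total is 0-dimensional.

# Option 1: recreate the divisor on the same device (the replacement line above).
probs_a = probs / torch.tensor(float(p_total), dtype=torch.float32, device=device)

# Option 2: extract a plain Python float, which broadcasts against any device.
probs_b = probs / p_total.item()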
According to this:
https://pytorch.org/blog/pytorch-0_4_0-migration-guide/
the detach() method returns an autograd-compatible tensor with requires_grad=False, so if we store this tensor in the replay buffer, it will keep all of the tensors in the loss computation graph alive in case they are needed for autograd's backward pass later - which never actually happens in our case, but PyTorch does not know that.
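To make that concrete, here is a small sketch (not the repository's actual code) contrasting the two storage strategies; the loss tensor is a hypothetical stand-in for the per-sample loss. One certain detail is that detach() returns a tensor that shares storage with the original, so keeping it around keeps that buffer referenced, whereas a numpy copy holds no reference to any PyTorch storage.

import numpy as np
import torch

# Hypothetical stand-in for the per-sample loss produced by a training step.
loss = torch.rand(32, requires_grad=True) * 2.0

# detach() gives a tensor with requires_grad=False, but it shares storage with
# `loss`, so storing it in the replay tree keeps that memory referenced.
priorities_torch = loss.detach()

# Converting to a plain numpy array copies the values and drops every reference
# to PyTorch storage, so nothing from the training step is retained.
priorities_np = loss.detach().cpu().numpy().copy()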
Thanks a lot for testing this - it seems like in general I should be using numpy arrays in the experience replay memory to prevent any PyTorch tracking overhead, and converting them to tensors as late as possible. I'll merge this now, but I'll try to go through the memory this week and shift more of it into numpy.
I was profiling this code and discovered that the line mem.update_priorities(idxs, loss.detach()) was causing huge memory leakage and slowdowns with the current version of PyTorch. Once I switched the replay memory to use numpy arrays instead of PyTorch tensors, memory usage dropped 3x and training ran 2-3x faster overall compared with the original version. So it looks like the current code is attaching something big, possibly the entire graph (despite the detach() call), to every node in the replay memory tree, and this simple change fixes it.
-Art
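For readers who want to see the shape of the change being described, here is a hedged, self-contained sketch of a numpy-backed priority update; the class, attribute, and method names (ReplayMemory, priorities, priority_exponent, update_priorities) are illustrative assumptions rather than the repository's actual implementation.

import numpy as np
import torch

class ReplayMemory:
  # Minimal stand-in: only the pieces needed to show the numpy-based priority
  # update are included; everything here is an assumption for illustration.
  def __init__(self, capacity, priority_exponent=0.5):
    self.priority_exponent = priority_exponent
    self.priorities = np.zeros(capacity, dtype=np.float32)  # numpy, not torch

  def update_priorities(self, idxs, priorities):
    # Convert the (already detached) loss tensor to numpy before storing, so
    # the memory never holds a reference to any PyTorch tensor or its storage.
    if isinstance(priorities, torch.Tensor):
      priorities = priorities.detach().cpu().numpy()
    self.priorities[idxs] = np.power(priorities, self.priority_exponent)

# Usage: a call site like mem.update_priorities(idxs, loss.detach()) can stay as
# it is, because the conversion to numpy happens inside the memory.
mem = ReplayMemory(capacity=8)
loss = torch.rand(3)
mem.update_priorities(np.array([0, 3, 5]), loss.detach())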