dylanthomas opened 7 years ago
We are not planning to implement it for now, but some people have indeed suggested that PyTorch may be faster than TF. It would be great if someone could implement GA3C in PyTorch following our guidelines.
I did a quick trial in one of my branches. Actually, TF is almost twice as fast, probably because the naive way I vectorized the loss involves a lot of function calls. The same issue arises in the Chainer version: the loss takes almost as much time to compute as the CNN itself. I think it could run faster if the loss were implemented as a dedicated layer.
Just FYI, my friend was able to reproduce both the speed and the performance of my A3C implementation with his PyTorch code. It batches data differently from GA3C, but the overall structure is similar.
Interesting, @ppwwyyxx! My naive implementation looks something like this:
I am not sure the problem is in the batching; it may rather be the explicit calls and the many intermediate steps in the loss computation.
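One way to check where the time actually goes is to time the forward pass and the loss separately. A minimal stdlib-only helper (my own sketch, not from the thread; the callables in the usage comment are placeholders for the snippet below):

```python
import time

def avg_time(fn, *args, n=100):
    """Average wall-clock time of fn(*args) over n runs, after one warm-up call."""
    fn(*args)  # warm-up (first call often pays one-time setup costs)
    t0 = time.perf_counter()
    for _ in range(n):
        fn(*args)
    return (time.perf_counter() - t0) / n

# Hypothetical usage against the snippet below:
#   fwd_t  = avg_time(lambda: self.model.forward_multistep(x_, c, h))
#   loss_t = avg_time(compute_loss)   # wrap the loss lines in a function first
```

If `loss_t` is comparable to `fwd_t`, that would support the "too many small ops" theory rather than the batching one.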
```python
p, v = self.model.forward_multistep(x_, c, h)
probs = F.softmax(p, dim=1)
# clamp instead of relu(probs - eps): relu would zero out small probs,
# and torch.log would then return -inf
log_probs = torch.log(probs.clamp(min=Config.LOG_EPSILON))
adv = rewards - v
adv = torch.masked_select(adv, mask)
# masked_select because of the variable-length inputs
log_probs_a = torch.masked_select(log_probs, a)
piloss = -torch.sum(log_probs_a * Variable(adv.data), 0)  # Variable(.data) stops the gradient
entropy = torch.sum(torch.sum(log_probs * probs, 1), 0) * self.beta
vloss = torch.sum(adv.pow(2), 0) / 2
loss = piloss + entropy + vloss
```
Does anyone know how to do this faster in PyTorch?
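One thing that might help (a hedged sketch, untested against the original code): use `F.log_softmax` plus `gather` so the loss is a handful of fused calls instead of softmax → subtract → relu → log plus several `masked_select`s. This assumes the variable-length episodes are first flattened down to their valid steps, so the masks disappear; all names here are illustrative.

```python
import torch
import torch.nn.functional as F

def a3c_loss(logits, values, actions, rewards, beta=0.01):
    """Hypothetical fused version of the loss above.

    logits:  [T, num_actions] raw policy outputs for the valid steps only
    values:  [T] value predictions
    actions: [T] int64 action indices (replaces the one-hot mask `a`)
    rewards: [T] discounted returns
    """
    log_probs = F.log_softmax(logits, dim=1)  # fused and numerically stable
    probs = log_probs.exp()
    # one gather call replaces masked_select over a one-hot mask
    log_probs_a = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    adv = rewards - values
    piloss = -(log_probs_a * adv.detach()).sum()  # detach() = stop-gradient, like Variable(adv.data)
    entropy = beta * (probs * log_probs).sum()
    vloss = adv.pow(2).sum() / 2
    return piloss + entropy + vloss
```

Since `log_softmax` never produces exact zeros, the `Config.LOG_EPSILON` clamp would no longer be needed.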
@ppwwyyxx Is there a public Git repo for your friend's PyTorch implementation?
Unfortunately no..
Is there any plan on the horizon to port this code to PyTorch?