NVlabs / GA3C

Hybrid CPU/GPU implementation of the A3C algorithm for deep reinforcement learning.
BSD 3-Clause "New" or "Revised" License

pyTorch #19

Open dylanthomas opened 7 years ago

dylanthomas commented 7 years ago

Is there any plan on the horizon to port this code to PyTorch?

ifrosio commented 7 years ago

We are not planning to implement it for now, but some people are indeed suggesting that PyTorch may be faster than TF. It would be great if someone could implement GA3C in PyTorch following our guidelines.

etienne87 commented 7 years ago

I did a quick trial in one of my branches. TF is actually almost twice as fast, probably because my naive vectorized loss involves a lot of function calls. The same issue arises in the Chainer version. The loss takes nearly as long to compute as the CNN, if not longer. I think it could run faster if the loss were implemented as a dedicated layer.

ppwwyyxx commented 7 years ago

Just FYI, my friend was able to reproduce both the speed and performance of my A3C implementation with his PyTorch code. It batches data differently from GA3C, but the overall structure is similar.

etienne87 commented 7 years ago

Interesting, @ppwwyyxx! My naive implementation gives something like this:

results txt

I am not sure the problem is in the batching. I suspect it is rather the explicit calls and the many computation steps needed for the loss.

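        # Forward pass over the whole multi-step batch.
        p, v = self.model.forward_multistep(x_, c, h)
        probs = F.softmax(p, dim=1)
        # Clamp rather than relu(probs - eps), which can yield log(0) = -inf.
        probs = torch.clamp(probs, min=Config.LOG_EPSILON)
        log_probs = torch.log(probs)
        adv = (rewards - v)
        adv = torch.masked_select(adv, mask)  # keep only the steps selected by mask
        log_probs_a = torch.masked_select(log_probs, a)  # we cannot use it because of variable length input
        # detach() replaces Variable(adv.data): no policy gradient flows into the value head.
        piloss = -torch.sum(log_probs_a * adv.detach(), 0)
        # Negative entropy scaled by beta; adding it to the loss favors exploration.
        entropy = torch.sum(torch.sum(log_probs * probs, 1), 0) * self.beta
        vloss = torch.sum(adv.pow(2), 0) / 2
        loss = piloss + entropy + vloss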
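
In equation form, and setting aside the epsilon handling, the snippet above appears to compute the loss below. The symbols are shorthand introduced here, not names from the GA3C code: $T$ is the set of steps kept by `mask`, $R_t$ stands for `rewards`, $V_t$ for `v`, $\beta$ for `self.beta`, and $\hat{A}_t$ is the advantage treated as a constant (no gradient). It also assumes `a` is a one-hot mask that is all zero on the padded steps.

$$
A_t = R_t - V_t, \qquad
L = -\sum_{t \in T} \hat{A}_t \,\log \pi(a_t \mid s_t)
    \;+\; \beta \sum_{t} \sum_{a} \pi(a \mid s_t)\,\log \pi(a \mid s_t)
    \;+\; \frac{1}{2} \sum_{t \in T} A_t^{2}
$$

Note that the middle (negative entropy) term is summed over every row of the batch as written, not only over the steps in $T$.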

Does anyone know how to do this more quickly in PyTorch?
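
One possible direction, as an unbenchmarked sketch rather than a confirmed fix: compute the loss with fewer ops by using `F.log_softmax` (which needs no epsilon clamp) and a single `gather` on action indices instead of `masked_select` on a one-hot mask. The function name `a3c_loss`, its argument layout, the flat `(N, ...)` batch shape, and the default `beta` are assumptions made for illustration. Unlike the snippet above, this version also masks the entropy term on padded steps.

```python
import torch
import torch.nn.functional as F


def a3c_loss(logits, values, rewards, actions, mask, beta=0.01):
    """Hypothetical fused A3C loss (illustrative, not the GA3C implementation).

    logits:  (N, num_actions) raw policy outputs
    values:  (N,) value estimates
    rewards: (N,) value targets, named as in the snippet above
    actions: (N,) int64 action indices (not one-hot)
    mask:    (N,) bool, False on padded steps
    """
    mask_f = mask.to(logits.dtype)
    # log_softmax is numerically stable, so no epsilon clamp is needed.
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    # Select log pi(a_t | s_t) with one gather.
    log_probs_a = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    # Zero the advantage on padded steps instead of changing tensor shapes.
    adv = (rewards - values) * mask_f
    # detach() keeps the policy term from sending gradients to the value head.
    pi_loss = -(log_probs_a * adv.detach()).sum()
    # Negative entropy; adding it to the loss favors exploration.
    neg_entropy = beta * ((probs * log_probs).sum(dim=1) * mask_f).sum()
    v_loss = 0.5 * adv.pow(2).sum()
    return pi_loss + neg_entropy + v_loss


# Usage on random data: 5 steps, 4 actions, last 2 steps are padding.
torch.manual_seed(0)
logits = torch.randn(5, 4, requires_grad=True)
values = torch.randn(5, requires_grad=True)
rewards = torch.randn(5)
actions = torch.randint(0, 4, (5,))
mask = torch.tensor([True, True, True, False, False])
loss = a3c_loss(logits, values, rewards, actions, mask)
loss.backward()
```

Multiplying by the mask keeps every tensor at a fixed shape, so padded steps contribute nothing to the loss or to the gradients. Whether this actually closes the gap with TF would need to be measured.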

dylanthomas commented 7 years ago

@ppwwyyxx Is there a public git repo for your friend's PyTorch implementation?

ppwwyyxx commented 7 years ago

Unfortunately no..