Closed harnvo closed 2 years ago
在policy.maven.MAVEN里面, cuda部分的代码长这个样子...
policy.maven.MAVEN
if self.args.cuda: self.z_policy.cuda() self.eval_rnn.cuda() self.target_rnn.cuda() self.eval_qmix_net.cuda() self.target_qmix_net.cuda() self.mi_net.cuda()
然而,cuda()只是一个会返回在GPU的tensor的函数,你还需要让x = x.cuda()才能让x变成在GPU的tensor。 #92
cuda()
x = x.cuda()
never mind, 是我阅读PyTorch 文档不仔细,抱歉...
在
policy.maven.MAVEN
里面, cuda部分的代码长这个样子...然而,
cuda()
只是一个会返回在GPU的tensor的函数,你还需要让x = x.cuda()
才能让x变成在GPU的tensor。 #92