coarse-graining / cgnet

learning coarse-grained force fields
BSD 3-Clause "New" or "Revised" License

axis vs. dim in pytorch functions #181

Closed: nec4 closed this issue 4 years ago

nec4 commented 4 years ago

Heyo! I found a little bug in the simulation code, though it seems to affect only users of PyTorch < 1.2:

Produced 100 initial coordinates.
Generating 100 simulations of length 1000000 saved at 10-step intervals (Fri Jun 26 23:42:01 2020)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-14-2b995afe27a7> in <module>
     26                  friction=friction, masses=masses)
     27 
---> 28 traj = sim.simulate()
     29 
     30 print(traj.shape)

<ipython-input-13-b79b4d3ebbf6> in simulate(self, overwrite)
    672             # save to arrays if relevant
    673             if (t+1) % self.save_interval == 0:
--> 674                 self._save_timepoint(x_new, v_new, forces, potential, t)
    675 
    676                 # save numpys if relevant; this can be indented here because

<ipython-input-13-b79b4d3ebbf6> in _save_timepoint(self, x_new, v_new, forces, potential, t)
    501         if v_new is not None:
    502             kes = 0.5 * torch.sum(torch.sum(self.masses[:, None]*v_new**2,
--> 503                                             axis=2), axis=1)
    504             self.kinetic_energies[save_ind, :] = kes
    505 

TypeError: sum() received an invalid combination of arguments - got (Tensor, axis=int), but expected one of:
 * (Tensor input)
 * (Tensor input, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: axis
 * (Tensor input, tuple of ints dim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, torch.dtype dtype, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, Tensor out)

This arises from using axis instead of dim in torch.sum(). According to the most current docs for torch.sum, the correct keyword is dim. Under PyTorch 1.1, changing axis to dim resolves the error above and the code seems to run fine. My guess is that PyTorch >= 1.2 added some interoperability with common numpy keywords/syntax, but I will check to make sure. I think this is something to consider fixing for the pytorch 1.1 branch; a sketch of the fix follows below. Let me know what you think!
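For concreteness, here is a minimal standalone sketch of the fix; masses and v_new below are stand-ins for self.masses and the velocity tensor passed to _save_timepoint, with illustrative shapes rather than the simulation's real state:

import torch

# dim works on every PyTorch version; the numpy-style axis alias is
# only accepted from ~1.2 onward, so dim is the portable choice.
masses = torch.ones(10)          # stand-in for self.masses: one mass per bead
v_new = torch.randn(100, 10, 3)  # illustrative velocities: (n_sims, n_beads, 3)

# Kinetic energy per simulation at this timepoint: 0.5 * sum_i m_i |v_i|^2,
# summing first over the xyz dimension (dim=2), then over beads (dim=1).
kes = 0.5 * torch.sum(torch.sum(masses[:, None] * v_new ** 2, dim=2), dim=1)
print(kes.shape)  # torch.Size([100]): one kinetic energy per simulation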

nec4 commented 4 years ago

Closing after merging #182.