ur-whitelab / hoomd-tf

A plugin that allows the use of Tensorflow in Hoomd-Blue for GPU-accelerated ML+MD
https://hoomd-tf.readthedocs.io
MIT License
30 stars 8 forks source link

Make atom selection more robust in iter_from_trajectory() #253

Closed hgandhi2411 closed 3 years ago

hgandhi2411 commented 3 years ago

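@mehradans92 noticed that `iter_from_trajectory()` throws an `InvalidArgumentError` when the system is large: the tensor passed to `tf.gather_nd` in `compute_nlist` has more elements than int32 indexing allows.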

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-15-4dc65c7bc98d> in <module>
      7 losses = []
      8 mapped_rdfs = []
----> 9 for inputs, ts in htf.iter_from_trajectory(32, u, selection='all', r_cut=r_cut, period =50):
     10     nlist = inputs[0]
     11     positions = inputs[1]
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/hoomd/htf/utils.py in iter_from_trajectory(nneighbor_cutoff, universe, selection, r_cut, period)
    287     # box_size = [box[0], box[1], box[2]]
    288     nlist = compute_nlist(atom_group.positions, r_cut=r_cut,
--> 289                           NN=nneighbor_cutoff, box_size=[box[0], box[1], box[2]])
    290     # if selection != 'all':
    291     #     universe = mda.Merge(universe.select_atoms(selection))
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/hoomd/htf/utils.py in compute_nlist(positions, r_cut, NN, box_size, sorted, return_types)
    393     flat_idx = tf.concat([idx, tf.reshape(topk.indices, [-1, 1])], -1)
    394     # mask is reapplied here, so those huge numbers won't still be in there.
--> 395     nlist_pos = tf.reshape(tf.gather_nd(dist_mat, flat_idx), [-1, NN, 3])
    396     nlist_mask = tf.reshape(tf.gather_nd(mask_cast, flat_idx), [-1, NN, 1])
    397 
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in gather_nd_v2(params, indices, batch_dims, name)
   5004 @dispatch.add_dispatch_support
   5005 def gather_nd_v2(params, indices, batch_dims=0, name=None):
-> 5006   return gather_nd(params, indices, name=name, batch_dims=batch_dims)
   5007 
   5008 
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    199     """Call target, and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args, **kwargs)
    202     except (TypeError, ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError, not a
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py in gather_nd(params, indices, name, batch_dims)
   4996       return params.gather_nd(indices, name=name)
   4997     except AttributeError:
-> 4998       return gen_array_ops.gather_nd(params, indices, name=name)
   4999   else:
   5000     return batch_gather_nd(params, indices, batch_dims=batch_dims, name=name)
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py in gather_nd(params, indices, name)
   3750       return _result
   3751     except _core._NotOkStatusException as e:
-> 3752       _ops.raise_from_not_ok_status(e, name)
   3753     except _core._FallbackException:
   3754       pass
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6841   message = e.message + (" name: " + name if name is not None else "")
   6842   # pylint: disable=protected-access
-> 6843   six.raise_from(core._status_to_exception(e.code, message), None)
   6844   # pylint: enable=protected-access
   6845 
~/.conda/envs/hoomd-tf2/lib/python3.7/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: params.NumElements() too large for int32 indexing: 6487308012 > 2147483647 [Op:GatherNd]
whitead commented 3 years ago

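The failure comes from `compute_nlist`, which gathers from a dense pairwise tensor whose size grows with the square of the number of particles. Here it holds 6,487,308,012 elements (3 × 46,502²), which exceeds the int32 indexing limit of 2,147,483,647. You can either select a smaller number of particles using the `selection` keyword, or reduce the number of neighbors (`nneighbor_cutoff`). Alternatively, use a GPU with more memory.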