torchmd / torchmd-net

Training neural network potentials
MIT License

Exception reloading model #267

Closed: peastman closed this 10 months ago

peastman commented 10 months ago

Reloading a saved model with load_model() fails with the following error:

get_neighbor_pairs_kernel(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, float cutoff_lower, float cutoff_upper, int max_num_pairs, bool loop, bool include_transpose) -> ((Tensor, Tensor, Tensor, Tensor)):
Expected a value of type 'float' for argument 'cutoff_lower' but instead found type 'int'.
:
  File "/home/peastman/workspace/torchmd-net/torchmdnet/models/utils.py", line 256
        if batch is None:
            batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
        edge_index, edge_vec, edge_weight, num_pairs = get_neighbor_pairs_kernel(
                                                       ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            strategy=self.strategy,
            positions=pos,

At line 261 of utils.py, if I change

            cutoff_lower=self.cutoff_lower,
            cutoff_upper=self.cutoff_upper,

to

            cutoff_lower=float(self.cutoff_lower),
            cutoff_upper=float(self.cutoff_upper),

the error goes away. I'm not sure whether that's the best place to fix it, but it seems a type conversion is needed somewhere.
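To show why the explicit float() cast helps, here is a minimal pure-Python sketch of strict schema matching. It only mimics the behavior reported in the traceback, where an int is rejected for a parameter declared as float with no implicit promotion. The names NEIGHBOR_KERNEL_SCHEMA and check_schema are hypothetical and are not part of torchmd-net or PyTorch.

```python
# Hypothetical schema: parameter name -> exact type required.
NEIGHBOR_KERNEL_SCHEMA = {"cutoff_lower": float, "cutoff_upper": float}


def check_schema(schema, **kwargs):
    """Raise TypeError unless each argument has exactly the declared type.

    Unlike ordinary Python arithmetic, no implicit int -> float promotion
    is done, mirroring the error message in the traceback above.
    """
    for name, expected in schema.items():
        value = kwargs[name]
        if type(value) is not expected:
            raise TypeError(
                f"Expected a value of type '{expected.__name__}' for argument "
                f"'{name}' but instead found type '{type(value).__name__}'."
            )


# An int cutoff (e.g. cutoff_lower=0) is rejected ...
try:
    check_schema(NEIGHBOR_KERNEL_SCHEMA, cutoff_lower=0, cutoff_upper=5.0)
except TypeError as err:
    print(err)

# ... while casting with float() at the call site satisfies the schema.
check_schema(NEIGHBOR_KERNEL_SCHEMA, cutoff_lower=float(0), cutoff_upper=float(5.0))
```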

peastman commented 10 months ago

The model in question is a TensorNet model with a ZBL prior. A similar model without the prior doesn't encounter the error, so it's probably connected to ZBL.

RaulPPelaez commented 10 months ago

train.py performs this conversion because its argument parser declares both options as floats:

    parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model')
    parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model')
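A minimal standalone sketch of what those two declarations do: argparse applies type=float to the command-line string, so even an integer-looking value such as "5" arrives as a float. The parser here is built in isolation for illustration and is not train.py's full parser.

```python
import argparse

parser = argparse.ArgumentParser()
# Same declarations as in train.py: type=float converts the command-line string.
parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model')
parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model')

# Passing "5" on the command line still yields a float.
args = parser.parse_args(['--cutoff-upper', '5'])
print(type(args.cutoff_upper).__name__, args.cutoff_upper)  # float 5.0
print(type(args.cutoff_lower).__name__, args.cutoff_lower)  # float 0.0
```

Keyword arguments passed directly to load_model never go through this parser, so they keep whatever type the caller used.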

But calling something like:

    model = load_model(..., cutoff_upper=5)  # Should be 5.0

will skip type checking entirely. load_model has no way to know the expected types, so I think the fix should go either here: https://github.com/torchmd/torchmd-net/blob/af64cdb94769a5c8d26188d79e0baee9f3b75f1d/torchmdnet/models/model.py#L49-L50

Or by typing the arguments in the individual models. I'm going to go with model.py, since it is the less invasive change.
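A minimal sketch of the model.py-side fix, assuming load_model merges the checkpoint's hyperparameters with the user's keyword arguments before building the model. The helper coerce_float_hparams, the FLOAT_HPARAMS list, and the example dictionary are hypothetical, not the code that was actually committed.

```python
# Hyperparameters that downstream TorchScript kernels expect as float.
# This list is an assumption for illustration; the real set may differ.
FLOAT_HPARAMS = ("cutoff_lower", "cutoff_upper")


def coerce_float_hparams(hparams, names=FLOAT_HPARAMS):
    """Return a copy of hparams with the given integer entries cast to float."""
    fixed = dict(hparams)
    for name in names:
        value = fixed.get(name)
        # Only promote plain ints; leave floats, None, and bools untouched.
        if isinstance(value, int) and not isinstance(value, bool):
            fixed[name] = float(value)
    return fixed


# Simulated checkpoint hyperparameters, overridden by load_model kwargs.
checkpoint_hparams = {"cutoff_lower": 0.0, "cutoff_upper": 5.0, "model": "tensornet"}
overrides = {"cutoff_upper": 5}  # the int passed by the user
hparams = coerce_float_hparams({**checkpoint_hparams, **overrides})
print(hparams["cutoff_upper"], type(hparams["cutoff_upper"]).__name__)  # 5.0 float
```

Casting at this single entry point leaves each model's signature unchanged, which is the trade-off behind preferring model.py over typing the arguments in every model.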

Thanks Peter, you are testing TMDNet hard these days!