Closed: peastman closed this issue 10 months ago
The model in question is a TensorNet model with a ZBL prior. A similar model without the prior doesn't encounter the error, so it's probably connected to ZBL.
train.py performs this conversion because it declares a float type for those arguments:

```python
parser.add_argument('--cutoff-lower', type=float, default=0.0, help='Lower cutoff in model')
parser.add_argument('--cutoff-upper', type=float, default=5.0, help='Upper cutoff in model')
```
But a call like:

```python
model = load_model(..., cutoff_upper=5)  # should be 5.0
```

is not type-checked at all.
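The asymmetry is easy to demonstrate with a minimal standalone sketch (using only argparse; the argument name mirrors the one above but the snippet is not torchmd-net code):

```python
import argparse

# argparse coerces the command-line string "5" to 5.0 because of type=float ...
parser = argparse.ArgumentParser()
parser.add_argument('--cutoff-upper', type=float, default=5.0)
args = parser.parse_args(['--cutoff-upper', '5'])
print(type(args.cutoff_upper).__name__)  # float

# ... but a plain keyword argument keeps whatever type the caller passed,
# so an int slips through untouched.
kwargs = {'cutoff_upper': 5}
print(type(kwargs['cutoff_upper']).__name__)  # int
```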
`load_model` has no way to know the types, so I think the fix should go either here:
https://github.com/torchmd/torchmd-net/blob/af64cdb94769a5c8d26188d79e0baee9f3b75f1d/torchmdnet/models/model.py#L49-L50
Or by typing the arguments in the individual models. I am going to go with model.py, since that is the less aggressive change.
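Either way, the fix amounts to coercing the known float hyperparameters before they reach the model. A minimal sketch of that idea (the helper name and the set of argument names are hypothetical, not the actual torchmd-net code):

```python
# Hyperparameters the models expect as floats (hypothetical list for illustration).
FLOAT_ARGS = {'cutoff_lower', 'cutoff_upper'}

def coerce_float_args(args: dict) -> dict:
    """Return a copy of args with known float hyperparameters cast to float."""
    out = dict(args)
    for key in FLOAT_ARGS:
        # bool is a subclass of int, so exclude it explicitly
        if key in out and isinstance(out[key], int) and not isinstance(out[key], bool):
            out[key] = float(out[key])
    return out

fixed = coerce_float_args({'cutoff_lower': 0, 'cutoff_upper': 5})
print(fixed)  # {'cutoff_lower': 0.0, 'cutoff_upper': 5.0}
```

Doing the coercion centrally, once, means none of the individual models need changed signatures.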
Thanks Peter, you are testing TMDNet hard these days!
When trying to reload a saved model with `load_model()`, it's failing with the error. At line 261 of utils.py, if I change
to
the error goes away. I'm not sure whether that's the best place to fix it, but it seems a type conversion is needed somewhere.
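The exact before/after code at utils.py line 261 is not reproduced above, but the kind of conversion in question can be sketched as follows (hypothetical helper, not the actual torchmd-net change):

```python
def as_float(value):
    """Accept an int where a float is required, e.g. a hyperparameter
    read back from a checkpoint that was saved with an int value."""
    if isinstance(value, bool):
        # bool is a subclass of int; reject it rather than silently cast
        raise TypeError("expected a number, got bool")
    return float(value)

print(as_float(5))    # 5.0
print(as_float(5.0))  # 5.0
```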