Hi!
I wanted to reproduce the results in espaloma, so I followed the training procedure suggested in https://espaloma.wangyq.net/experiments/qm_fitting.html.
However, an error occurs when I try to calculate the training set performance:
```
Traceback (most recent call last):
  File "train_gen2.py", line 86, in <module>
    u = torch.cat(u, dim=0)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 42 but got size 21 for tensor number 1 in the list
```
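One likely reading of this message, assuming each graph's `g.nodes['g'].data['u']` has shape `(1, n_snapshots)` (one row per molecule, one column per conformation; this shape is an assumption, not something stated in the tutorial), is that two molecules in `ds_tr` have 42 and 21 conformations. Joining along `dim=0` stacks rows, so all other dimensions must match. The minimal sketch below reproduces the same kind of error with dummy tensors of those hypothetical shapes:

```python
import torch

# Hypothetical per-molecule energy tensors, assumed shape (1, n_snapshots):
# the first molecule has 42 conformations, the second has 21.
u = [torch.zeros(1, 42), torch.zeros(1, 21)]

message = None
try:
    # Stacking along dim=0 requires every other dimension (here dim 1,
    # the number of conformations) to have the same size.
    torch.cat(u, dim=0)
except RuntimeError as err:
    message = str(err)

print(message)
```

The exact wording of the message differs between PyTorch versions, but it names both sizes (42 and 21).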
For both the training set and the validation set performance, I used the same code as in https://espaloma.wangyq.net/experiments/qm_fitting.html:
```python
with torch.no_grad():
    for idx_epoch in range(10000):
        espaloma_model.load_state_dict(
            torch.load("%s.th" % idx_epoch)
        )

        # training set performance
        u = []
        u_ref = []
        for g in ds_tr:
            if torch.cuda.is_available():
                g.heterograph = g.heterograph.to("cuda:0")
            espaloma_model(g.heterograph)
            u.append(g.nodes['g'].data['u'])
            u_ref.append(g.nodes['g'])
        u = torch.cat(u, dim=0)
        u_ref = torch.cat(u_ref, dim=0)
        loss_tr.append(inspect_metric(u, u_ref))

        # validation set performance
        u = []
        u_ref = []
        for g in ds_vl:
            if torch.cuda.is_available():
                g.heterograph = g.heterograph.to("cuda:0")
            espaloma_model(g.heterograph)
            u.append(g.nodes['g'].data['u'])
            u_ref.append(g.nodes['g'])
        u = torch.cat(u, dim=0)
        u_ref = torch.cat(u_ref, dim=0)
        loss_vl.append(inspect_metric(u, u_ref))
```
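A possible workaround, sketched under assumptions rather than taken from the espaloma docs: flatten each per-molecule tensor before concatenating, so molecules with different numbers of conformations can be joined. The helper names `concat_energies` and `evaluate` are hypothetical. The sketch also assumes the reference energies live in `g.nodes['g'].data['u_ref']`; the loop above appends `g.nodes['g']` itself (a node view, not a tensor), so the second `torch.cat` would likely fail too once the first one is fixed.

```python
import torch


def concat_energies(tensors):
    """Flatten per-molecule energy tensors and join them into one 1-D tensor.

    Each tensor is assumed to have shape (1, n_snapshots), where n_snapshots
    may differ between molecules, so a plain torch.cat(..., dim=0) fails.
    """
    return torch.cat([t.reshape(-1) for t in tensors], dim=0)


@torch.no_grad()
def evaluate(model, dataset, metric, device=None):
    """Hypothetical helper: score `model` on every graph in `dataset`.

    Assumes each item exposes `.heterograph` and `.nodes['g'].data`, that the
    model writes predicted energies to data['u'], and that reference energies
    are stored under data['u_ref'].
    """
    u, u_ref = [], []
    for g in dataset:
        if device is not None:
            g.heterograph = g.heterograph.to(device)
        # Assumed to write predicted energies into the graph in place.
        model(g.heterograph)
        u.append(g.nodes['g'].data['u'])
        # Append the reference tensor, not the node view itself.
        u_ref.append(g.nodes['g'].data['u_ref'])
    return metric(concat_energies(u), concat_energies(u_ref))
```

Inside the epoch loop, each of the two inner blocks could then become something like `loss_tr.append(evaluate(espaloma_model, ds_tr, inspect_metric, "cuda:0" if torch.cuda.is_available() else None))`. Concatenating along `dim=1` would also work for `(1, n_snapshots)` tensors. Flattening discards which snapshots belong to which molecule, so if `inspect_metric` is meant to compare energies per molecule (for example after centering each molecule's energies), computing the metric per graph and averaging may be more appropriate.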
Thanks for any help!