Hi! We are working on improving and polishing the code for creating the molecule datasets. Right after that we will do an example of how to use our model. It will take us a few days.
We added a clearer example of how to use our code in the EXAMPLE.ipynb notebook. It shows how to run our code up to the training loop.
We also cleaned up the code responsible for feature generation, so it is now clear and easy to follow.
@genec1 does this solve your issue?
Looks good! I will try it out in a few days and report back. Appreciate the quick response!
Closing for now. But please reach out if you have any questions!
Thanks, kudkudak. The updated EXAMPLE.ipynb is much clearer. I'm still missing something, though, as training does not converge. Here is my simple attempt at training:
model.cuda()
loss_fn = nn.MSELoss()
device = torch.device("cuda")
train_losses = []
model.train()
for epoch in range(0, 50):
    for batch in data_loader:
        model.zero_grad()
        adjacency_matrix, node_features, distance_matrix, y = batch
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
        output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
        loss = loss_fn(torch.flatten(output), torch.flatten(y.to(device)))
        train_losses.append(loss.item())
        loss.backward()
        # print(".", end="", flush=True)
    # print(epoch)
# print("*")
plt.figure(figsize=(10,5))
plt.plot(train_losses)
I don't see any optimizer in your code snippet, this might be the issue.
I added an optimizer and a scheduler and am seeing a decrease in loss now. Thanks for pointing out my oversight!
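For anyone following along, here is a minimal sketch of what a loop with an optimizer and a scheduler can look like. It is not the exact code used in this thread: the tiny linear model and synthetic data are stand-ins for the MAT model and data_loader, and the choice of Adam, StepLR, and their settings is an assumption made only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in data and model (assumptions, not the MAT model or its DataLoader).
x = torch.randn(64, 4)
y = x @ torch.tensor([1.0, -2.0, 0.5, 3.0])
model = nn.Linear(4, 1)

loss_fn = nn.MSELoss()
# Assumed optimizer and scheduler; any torch.optim optimizer would fill this role.
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

train_losses = []
model.train()
for epoch in range(50):
    for i in range(0, len(x), 16):
        xb, yb = x[i:i + 16], y[i:i + 16]
        optimizer.zero_grad()          # clear gradients from the previous step
        output = model(xb)
        loss = loss_fn(torch.flatten(output), torch.flatten(yb))
        loss.backward()                # compute gradients
        optimizer.step()               # update the weights (the step missing above)
        train_losses.append(loss.item())
    scheduler.step()                   # decay the learning rate once per epoch
```

The key difference from the earlier snippet is the optimizer.step() call after loss.backward(): without it the gradients are computed but the weights never change, so the loss cannot decrease.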
Here is the log10(loss) over batch without any hyperparameter optimization:
Closing for now, please reopen if you have any other questions.
Hello! I'm trying to do a test run of the pre-trained MAT model. I've run SMILES through the DataLoader, which returns the following nine variables:
The model itself, as instantiated in load_weights.ipynb, is class GraphTransformer, which needs these arguments:
Is there example code that shows how to go from the output of the DataLoader into the transformer?