ardigen / MAT

The official implementation of the Molecule Attention Transformer.

Assistance needed for running model #3

Closed genec1 closed 4 years ago

genec1 commented 4 years ago

Hello! I'm trying to do a test run of the pre-trained MAT model. I've run SMILES through the DataLoader, which returns the following nine variables:

adj afm bft orderAtt aromAtt conjAtt ringAtt distances label

The model itself, as instantiated in load_weights.ipynb, is class GraphTransformer which needs these arguments:

src src_mask adj_matrix distances_matrix edges_att

Is there example code that shows how to go from the output of the DataLoader into the transformer?
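My best guess at how these line up, based only on the argument names, is something like the sketch below, but I am not sure it is correct (edges_att in particular I simply leave as None):

import torch

# batch comes from the DataLoader described above; model is the GraphTransformer
adj, afm, bft, orderAtt, aromAtt, conjAtt, ringAtt, distances, label = batch

src = afm                                           # atom feature matrix
src_mask = torch.sum(torch.abs(afm), dim=-1) != 0   # mask out zero-padded atoms
adj_matrix = adj
distances_matrix = distances
edges_att = None                                    # not sure what to pass here

output = model(src, src_mask, adj_matrix, distances_matrix, edges_att)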

Mazzza commented 4 years ago

Hi! We are working on improving and polishing the code for creating the molecule datasets. Right after that, we will add an example of how to use our model. It should take a few days.

Mazzza commented 4 years ago

We added a clearer example of how to use our code in the EXAMPLE.ipynb notebook. It shows how to run our code all the way up to the training loop.

We also cleaned up the code responsible for feature generation, so it should now be easier to follow.

kudkudak commented 4 years ago

@genec1 does this solve your issue?

genec1 commented 4 years ago

Looks good! I will try it out in a few days and report back. Appreciate the quick response!

kudkudak commented 4 years ago

Closing for now. But please reach out if you have any questions!

genec1 commented 4 years ago

Thanks, kudkudak. The updated EXAMPLE.ipynb is much clearer. I'm still missing something, though, as training is not converging. Here is my simple attempt at a training loop:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

model.cuda()
loss_fn = nn.MSELoss()
device = torch.device("cuda")

train_losses = []
model.train()

for epoch in range(0, 50):
    for batch in data_loader:
        model.zero_grad()
        adjacency_matrix, node_features, distance_matrix, y = batch
        # mask out zero-padded atoms (rows whose feature vector is all zeros)
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
        output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
        loss = loss_fn(torch.flatten(output), torch.flatten(y.to(device)))
        train_losses.append( loss.item() )
        loss.backward()
#        print(".", end="", flush=True)
#    print(epoch)

#print("*")    

plt.figure(figsize=(10,5))
plt.plot(train_losses)

[Plot: loss per batch]

Mazzza commented 4 years ago

I don't see an optimizer in your code snippet; that might be the issue.
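As a minimal sketch of the fix, reusing the loop from the snippet above (Adam with lr=1e-4 is just an illustrative choice, not a recommendation):

import torch
import torch.nn as nn

# model, data_loader and device are defined as in the snippet above
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
train_losses = []

for epoch in range(50):
    for batch in data_loader:
        optimizer.zero_grad()
        adjacency_matrix, node_features, distance_matrix, y = batch
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
        output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
        loss = loss_fn(torch.flatten(output), torch.flatten(y.to(device)))
        loss.backward()
        optimizer.step()                 # the missing step: apply the gradients
        train_losses.append(loss.item())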

genec1 commented 4 years ago

I added an optimizer and a scheduler and am seeing a decrease in loss now. Thanks for pointing out my oversight!
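For reference, one possible way to wire a scheduler into the loop above; Adam plus ReduceLROnPlateau here are illustrative stand-ins, not necessarily the exact choices used:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

for epoch in range(50):
    epoch_losses = []
    for batch in data_loader:
        optimizer.zero_grad()
        adjacency_matrix, node_features, distance_matrix, y = batch
        batch_mask = torch.sum(torch.abs(node_features), dim=-1) != 0
        output = model(node_features, batch_mask, adjacency_matrix, distance_matrix, None)
        loss = loss_fn(torch.flatten(output), torch.flatten(y.to(device)))
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    # lower the learning rate when the mean epoch loss plateaus
    scheduler.step(sum(epoch_losses) / len(epoch_losses))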

Here is the log10(loss) over batch without any hyperparameter optimization:

[Plot: log10(loss) per batch]

Mazzza commented 4 years ago

Closing for now, please reopen if you have any other questions.