pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Optimizing the graph instead of the model parameters #464

Closed EoAelinr closed 5 years ago

EoAelinr commented 5 years ago

❓ Questions & Help

Hi.

I am trying to optimize the nodes of a graph (A) according to the features of another graph (B); the goal is to make their features correlate. Both A and B are loaded as Batch objects with FAUST() and DataLoader(). My network is as follows:

```
Network(
  (conv1): SplineConv(1, 32)
  (conv2): SplineConv(32, 64)
  (conv3): SplineConv(64, 64)
  (conv4): SplineConv(64, 64)
  (conv5): SplineConv(64, 64)
  (loss_5): Loss()
  (conv6): SplineConv(64, 64)
  pooling
  (lin1): Linear(in_features=64, out_features=256, bias=True)
  (lin2): Linear(in_features=256, out_features=10, bias=True)
)
```

Its weights were obtained by training it to classify FAUST poses. The loss is computed in the Loss module, which has access to B's precomputed features.

My loss is non-zero (e.g., 643.201843), but A.x does not change at all. What might I be doing wrong?

My optimization loop is as follows:

```python
import torch

optimizer = torch.optim.Adam([A.x])
run = [0]
while run[0] <= num_steps:

    def closure():
        optimizer.zero_grad()
        model(A)  # forward pass; the loss is then read from model.loss

        loss = model.loss
        loss.backward()

        run[0] += 1
        if run[0] % 50 == 0:
            print("run {}:".format(run[0]))
            print('Loss : {:4f}'.format(loss))
            print()

        return loss

    optimizer.step(closure)
```

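The thread does not show the `Loss()` module itself. The closure loop above resembles PyTorch's neural style transfer tutorial, where such a module is a pass-through layer that records a loss against fixed target features. A minimal sketch of that pattern, assuming a mean-squared-error objective; `FeatureLoss` is a hypothetical name, not the poster's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureLoss(nn.Module):
    """Hypothetical stand-in for the Loss() layer in the network above.

    Stores B's precomputed features as a fixed target, records the distance
    to them in `self.loss`, and returns its input unchanged so the forward
    pass continues to the next layer.
    """

    def __init__(self, target: torch.Tensor):
        super().__init__()
        # Detach so the target is treated as a constant and never optimized.
        self.target = target.detach()
        self.loss = torch.tensor(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: MSE between A's and B's features at this layer.
        self.loss = F.mse_loss(x, self.target)
        return x
```

Because the layer returns `x` untouched, it can sit between `conv5` and `conv6` without changing the network's output.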
rusty1s commented 5 years ago

I have never seen a closure used for PyTorch optimization, but this should nonetheless work. Is A.x actually a parameter, i.e., does it require gradients? You should check the gradient flow.
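One way to check the gradient flow is to backpropagate once and see which tensors actually received a `.grad`. A minimal sketch in plain PyTorch, where `nn.Linear` is a hypothetical stand-in for the trained SplineConv network and `x` for `A.x`:

```python
import torch
import torch.nn as nn


def grad_flow(loss: torch.Tensor, named: dict) -> dict:
    """Backpropagate `loss` and report, per name, whether a gradient arrived."""
    for t in named.values():
        t.grad = None  # clear stale gradients first
    loss.backward()
    return {name: t.grad is not None for name, t in named.items()}


torch.manual_seed(0)
model = nn.Linear(3, 1)  # stand-in for the trained network
x = torch.randn(4, 3)    # stand-in for A.x: a plain tensor by default
loss = model(x).pow(2).sum()
report = grad_flow(loss, {"x": x, "weight": model.weight})
print(report)
```

The weights receive gradients, so `backward()` succeeds and the loss looks healthy, but `x.grad` stays `None`. Optimizers skip parameters whose gradient is `None`, which matches the symptom of a non-zero loss with an unchanged `A.x`.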

EoAelinr commented 5 years ago

Thank you for your answer. You are right: A.x wasn't a parameter. Adding A.x.requires_grad_(True) in closure() allowed it to be optimized.
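The fix amounts to optimizing the network's input instead of its weights. A minimal sketch, assuming a hypothetical `nn.Linear` stand-in for the trained network and an all-zeros stand-in target; freezing the weights with `requires_grad_(False)` is an optional extra, not something from the thread:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the trained network and B's target features.
model = nn.Linear(3, 2)
model.requires_grad_(False)  # optional: freeze weights so only the input moves
target = torch.zeros(4, 2)

# Stand-in for A.x: mark it as requiring gradients before the forward pass.
x = torch.randn(4, 3).requires_grad_(True)
optimizer = torch.optim.Adam([x], lr=0.05)

x_start = x.detach().clone()
losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = (model(x) - target).pow(2).mean()
    loss.backward()  # the gradient now reaches x
    optimizer.step()
    losses.append(loss.item())

print(f"loss {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Calling `requires_grad_(True)` once, before building the optimizer, is equivalent to calling it inside `closure()`: the flag is set on the same tensor object that the optimizer holds.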

The features that I want to optimize are the xyz coordinates associated with each node (the graph is a mesh). However, I am unsure how to update those according to the optimized A.x. Do you think I could update A.edge_attr instead, or might that produce infeasible meshes (e.g., triangles with edge lengths a, b, c where a > b + c)?
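The infeasibility concern can be made concrete. Edge lengths computed from actual vertex positions always satisfy the triangle inequality, while per-edge values edited independently need not. A small check, assuming PyG's `[3, num_faces]` layout for `face`; both helper names are hypothetical:

```python
import torch


def triangle_edge_lengths(pos: torch.Tensor, face: torch.Tensor) -> torch.Tensor:
    """Edge lengths (a, b, c) of each triangle, computed from xyz positions.

    pos:  [num_nodes, 3] vertex coordinates
    face: [3, num_faces] vertex indices per triangle
    """
    p0, p1, p2 = pos[face[0]], pos[face[1]], pos[face[2]]
    a = (p1 - p2).norm(dim=-1)
    b = (p0 - p2).norm(dim=-1)
    c = (p0 - p1).norm(dim=-1)
    return torch.stack([a, b, c], dim=-1)


def satisfies_triangle_inequality(lengths: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """True per face if every edge is at most the sum of the other two."""
    a, b, c = lengths.unbind(dim=-1)
    return (a <= b + c + eps) & (b <= a + c + eps) & (c <= a + b + eps)


torch.manual_seed(0)
pos = torch.randn(4, 3)
face = torch.tensor([[0, 0], [1, 2], [2, 3]])  # triangles (0,1,2) and (0,2,3)
from_pos = triangle_edge_lengths(pos, face)
free = torch.tensor([[5.0, 1.0, 1.0]])         # independently edited: a > b + c
print(satisfies_triangle_inequality(from_pos))
print(satisfies_triangle_inequality(free))
```

The small `eps` only absorbs floating-point rounding; lengths derived from real positions can never violate the inequality by more than that.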

rusty1s commented 5 years ago

I wouldn't try to optimize A when working with meshes; you will certainly get corrupted meshes. Instead, simply optimize the xyz coordinates based on the mesh connectivity.
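A minimal sketch of that idea: the connectivity (`edge_index`) stays fixed, only the positions are a leaf tensor with gradients, and edge features are recomputed from the positions on every step. The unnormalized offsets below are a simplified stand-in for pseudo-coordinates such as those from PyG's `Cartesian` transform, and the toy triangle and loss are hypothetical:

```python
import torch


def edge_offsets(pos: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Per-edge offsets pos[col] - pos[row]; differentiable w.r.t. pos."""
    row, col = edge_index
    return pos[col] - pos[row]


torch.manual_seed(0)
# Hypothetical toy mesh: one triangle, each edge stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                           [1, 0, 2, 1, 0, 2]])
pos_b = torch.tensor([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0]])
target = edge_offsets(pos_b, edge_index)        # stand-in for B's features

pos_a = torch.randn(3, 3).requires_grad_(True)  # A.pos: the only optimized tensor
optimizer = torch.optim.Adam([pos_a], lr=0.05)

losses = []
for step in range(300):
    optimizer.zero_grad()
    # Recompute edge features from pos each step; a precomputed edge_attr
    # would be a constant and pass no gradient back to pos.
    feats = edge_offsets(pos_a, edge_index)
    loss = (feats - target).pow(2).mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss {losses[0]:.4f} -> {losses[-1]:.4f}")
```

In a real pipeline, the recomputed pseudo-coordinates would be fed to the frozen SplineConv network in place of this toy loss. Because every triangle is rebuilt from vertex positions, the mesh stays geometrically valid throughout.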

EoAelinr commented 5 years ago

I had missed that the gradient could flow all the way back to it. Optimizing A.pos worked. Thank you.