benedekrozemberczki / SimGNN

A PyTorch implementation of "SimGNN: A Neural Network Approach to Fast Graph Similarity Computation" (WSDM 2019).
GNU General Public License v3.0

Batch data is not processed in parallel #15

Closed kxhit closed 5 years ago

kxhit commented 5 years ago

Hi! @benedekrozemberczki Thanks for your PyTorch implementation of SimGNN. I noticed that in the `process_batch()` function, every sample (one graph pair) in a batch is processed one by one instead of in parallel. I think this is inefficient. Do you have any idea how to improve this? Thanks!!!

```python
def process_batch(self, batch):
    """
    Forward pass with a batch of data.
    :param batch: Batch of graph pair locations.
    :return loss: Loss on the batch.
    """
    self.optimizer.zero_grad()
    losses = 0
    for graph_pair in batch:
        data = process_pair(graph_pair)
        data = self.transfer_to_torch(data)
        target = data["target"]
        prediction = self.model(data)
        losses = losses + torch.nn.functional.mse_loss(data["target"], prediction)
    losses.backward(retain_graph=True)
    self.optimizer.step()
    loss = losses.item()
    return loss
```

benedekrozemberczki commented 5 years ago

Yes. You can use block diagonal batching.
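For readers unfamiliar with the idea: block-diagonal batching stacks the adjacency matrices of several graphs into one large block-diagonal matrix, so a single (typically sparse) matrix multiply aggregates neighbours for every graph in the batch at once, instead of looping over graph pairs. A minimal NumPy sketch (not the repository's code; the adjacency matrices and the one-hop aggregation `A @ X` are illustrative assumptions):

```python
import numpy as np

def block_diag(mats):
    """Place each adjacency matrix on the diagonal of one big matrix.
    Off-diagonal blocks stay zero, so graphs cannot exchange messages."""
    n = sum(m.shape[0] for m in mats)
    out = np.zeros((n, n))
    offset = 0
    for m in mats:
        k = m.shape[0]
        out[offset:offset + k, offset:offset + k] = m
        offset += k
    return out

# Two toy graphs (hypothetical adjacency matrices and node features).
A1 = np.array([[0., 1.],
               [1., 0.]])
A2 = np.array([[0., 1., 0.],
               [1., 0., 1.],
               [0., 1., 0.]])
X1 = np.ones((2, 4))   # 2 nodes, 4 feature channels
X2 = np.ones((3, 4))   # 3 nodes, 4 feature channels

# Batched: one matmul over the block-diagonal adjacency replaces
# a per-graph Python loop.
A = block_diag([A1, A2])
X = np.vstack([X1, X2])
H = A @ X

# The batched result matches the per-graph computations, stacked.
assert np.allclose(H[:2], A1 @ X1)
assert np.allclose(H[2:], A2 @ X2)
```

In PyTorch the same pattern is usually implemented with a sparse block-diagonal adjacency (or a `batch` index vector, as in PyTorch Geometric), so one forward pass covers the whole batch and a single `loss.backward()` replaces the per-pair `retain_graph=True` accumulation.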