Hi @benedekrozemberczki! Thanks for your PyTorch implementation of SimGNN.
I noticed that in the `process_batch()` function, every sample (one graph pair) in a batch is processed one by one instead of in parallel, which seems inefficient.
Do you have any ideas for improving this? Thanks!
```python
def process_batch(self, batch):
    """
    Forward pass with a batch of data.
    :param batch: Batch of graph pair locations.
    :return loss: Loss on the batch.
    """
    self.optimizer.zero_grad()
    losses = 0
    for graph_pair in batch:
        data = process_pair(graph_pair)
        data = self.transfer_to_torch(data)
        target = data["target"]
        prediction = self.model(data)
        losses = losses + torch.nn.functional.mse_loss(data["target"], prediction)
    losses.backward(retain_graph=True)
    self.optimizer.step()
    loss = losses.item()
    return loss
```
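For what it's worth, one common way to avoid the per-pair Python loop is to pad all graphs in a batch to a common node count and run the propagation as a single batched matrix multiplication with `torch.bmm`. This is not SimGNN's actual code, just a minimal sketch of the general technique; the function name `batched_gcn_layer` and all shapes here are made up for illustration:

```python
import torch

def batched_gcn_layer(adj, feats, weight):
    """One graph-propagation step over a whole batch at once.

    adj:    (B, N, N) padded adjacency matrices (zero rows/cols for padding)
    feats:  (B, N, F_in) padded node feature matrices
    weight: (F_in, F_out) shared layer weight
    """
    # bmm aggregates neighbor features for every graph in the batch in one call
    return torch.relu(torch.bmm(adj, feats) @ weight)

# Toy batch: 3 graphs padded to 4 nodes each, 5 input features, 8 output features
adj = torch.rand(3, 4, 4)
feats = torch.rand(3, 4, 5)
weight = torch.rand(5, 8)
out = batched_gcn_layer(adj, feats, weight)
print(out.shape)  # torch.Size([3, 4, 8])
```

With this layout the model sees the whole batch in a single forward pass, so you compute one loss and call `backward()` once, with no need for `retain_graph=True`. The trade-off is wasted compute on padding when graph sizes vary a lot; grouping graphs of similar size into the same batch mitigates that.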