Closed kaka2nd closed 2 years ago
There are two different approaches to do this:

1. Return two `data` objects in your `__getitem__` function and use a custom collate function:
```python
def __getitem__(self, index):
    return data1, data2
```

```python
from torch_geometric.data import Batch

def custom_collate(data_list):
    # Each item in data_list is a (data1, data2) pair; batch each side separately.
    batch_1 = Batch.from_data_list([d[0] for d in data_list])
    batch_2 = Batch.from_data_list([d[1] for d in data_list])
    return batch_1, batch_2
```

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, ..., collate_fn=custom_collate)
for batch_1, batch_2 in loader:
    ...
```
2. Put both graphs into *a single* `data` object:

```python
def __getitem__(self, index):
    data = Data(edge_index_1=..., edge_index_2=..., x_1=..., x_2=..., num_nodes=None)
    return data
```
and make use of the `follow_batch` argument in the `DataLoader`:

```python
# Note: PyG 2.0+ provides this class as torch_geometric.loader.DataLoader.
from torch_geometric.data import DataLoader

loader = DataLoader(dataset, ..., follow_batch=['x_1', 'x_2'])
for data in loader:
    ...
```
Note that we explicitly set `num_nodes` to `None` because the `data` object holds two graphs. This omits the creation of a `data.batch` assignment vector. Instead, the `follow_batch` argument creates two assignment vectors for each mini-batch, named `data.x_1_batch` and `data.x_2_batch`, respectively.
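To make the difference between the two approaches concrete, here is a minimal plain-Python sketch of what each one hands to the training loop. Python lists stand in for PyG objects, and `collate_pairs` and `assignment_vector` are hypothetical helper names, not part of PyG: the first mimics how `custom_collate` regroups a list of pairs, the second mimics the kind of vector that `follow_batch` exposes as `data.x_1_batch`.

```python
def collate_pairs(samples):
    """Approach 1: regroup [(a0, b0), (a1, b1), ...] into ([a0, a1, ...], [b0, b1, ...]).

    In the real collate function, each group is then passed to Batch.from_data_list.
    """
    firsts = [s[0] for s in samples]
    seconds = [s[1] for s in samples]
    return firsts, seconds


def assignment_vector(nodes_per_graph):
    """Approach 2: map every node to the index of the example it came from."""
    vector = []
    for example_idx, num_nodes in enumerate(nodes_per_graph):
        vector.extend([example_idx] * num_nodes)
    return vector


# Three samples, each holding a pair of (stand-in) graphs.
samples = [("g1_a", "g2_a"), ("g1_b", "g2_b"), ("g1_c", "g2_c")]
batch_1, batch_2 = collate_pairs(samples)

# A mini-batch of two examples: the first graphs have 3 and 2 nodes,
# the second graphs have 1 and 4 nodes.
x_1_batch = assignment_vector([3, 2])
x_2_batch = assignment_vector([1, 4])
print(batch_1, batch_2)
print(x_1_batch, x_2_batch)
```

Because the two graphs in a pair generally have different node counts, the two vectors differ, which is why a single `data.batch` vector would not be enough here.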
Edit: Another straightforward approach is to initialize two dataloaders:
```python
from torch_geometric.data import DataLoader

loader_1 = DataLoader(dataset_1, ...)
loader_2 = DataLoader(dataset_2, ...)
for data_1, data_2 in zip(loader_1, loader_2):
    ...
```
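One thing to keep in mind with the two-loader approach: `zip` pairs batches purely by position and stops at the shorter iterable, so the two loaders presumably need to iterate in the same order (for example, with shuffling disabled or driven by a shared permutation). The sketch below uses plain lists and a hypothetical `batches` helper in place of real loaders to show both effects.

```python
import random


def batches(items, batch_size, order=None):
    """Yield fixed-size batches of `items`, optionally following an index order."""
    order = list(range(len(items))) if order is None else order
    for start in range(0, len(order), batch_size):
        yield [items[i] for i in order[start:start + batch_size]]


dataset_1 = ["a0", "a1", "a2", "a3", "a4"]
dataset_2 = ["b0", "b1", "b2", "b3", "b4"]

# A shared permutation keeps sample i of dataset_1 paired with sample i of dataset_2.
rng = random.Random(0)
perm = list(range(len(dataset_1)))
rng.shuffle(perm)

paired = list(zip(batches(dataset_1, 2, perm), batches(dataset_2, 2, perm)))

# zip stops at the shorter iterable: 3 batches zipped with 2 batches gives 2 pairs.
truncated = list(zip(batches(dataset_1, 2), batches(dataset_2[:3], 2)))
```

If the two sides were shuffled independently, the positions would no longer correspond to the same underlying pair, and any shared label would be attached to the wrong graphs.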
Wow! Thanks for the detailed reply. Have a good day (or night).
Thanks for the straightforward approach mentioned in the Edit. In this case, how is the common label (or target) for the involved graphs specified?
One possible method to tackle this would be to keep the common label in both the `data_1` and `data_2` objects.
Yeah, that would be possible. With a more recent PyG release, it's also possible to do the following now:
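```python
import torch

class PairDataset(torch.utils.data.Dataset):
    def __init__(self, dataset1, dataset2, label):
        # Keep references so __getitem__ can index into them.
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.label = label

    def __getitem__(self, idx):
        return self.dataset1[idx], self.dataset2[idx], self.label[idx]
```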
and use the `torch_geometric.data.DataLoader` for batching:

```python
loader = DataLoader(PairDataset(dataset1, dataset2, labels), batch_size=32, ...)
for data1, data2, label in loader:
    ...
```
Thanks for the quick reply. I tried the new approach and got the following error.
```python
loader = DataLoader(PairDataset(dataset1, dataset2, labels), batch_size=32, ...)
for data1, data2, label in loader:
    ...
```

```
TypeError: object of type 'PairDataset' has no len()
```
Did I miss something?
You need to additionally implement the `__len__` method (sorry for missing that):

```python
def __len__(self):
    return len(self.dataset1)
```
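Putting the two snippets together, a map-style dataset needs both `__getitem__` and `__len__`, since the loader's default sampling relies on `len()` to know how many samples exist. The sketch below is a torch-free stand-in (plain lists instead of PyG datasets, and no `torch.utils.data.Dataset` base class, both simplifications made for illustration) that reproduces the `TypeError` and shows that adding `__len__` resolves it. `PairWithoutLen` is a hypothetical name used only for the contrast.

```python
class PairWithoutLen:
    """Pairs two indexable datasets with a label, but forgets __len__."""

    def __init__(self, dataset1, dataset2, labels):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.labels = labels

    def __getitem__(self, idx):
        return self.dataset1[idx], self.dataset2[idx], self.labels[idx]


class PairDataset(PairWithoutLen):
    """Same pairing logic, plus the __len__ that len() relies on."""

    def __len__(self):
        return len(self.dataset1)


graphs_1 = ["g1_a", "g1_b"]
graphs_2 = ["g2_a", "g2_b"]
labels = [0, 1]

try:
    len(PairWithoutLen(graphs_1, graphs_2, labels))
    error_message = None
except TypeError as exc:
    error_message = str(exc)

pairs = PairDataset(graphs_1, graphs_2, labels)
print(error_message)
print(len(pairs), pairs[1])
```

Returning `len(self.dataset1)` assumes both datasets and the label list have the same length; if they might not, taking the minimum of the three lengths would be a safer choice.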
I use pytorch_geometric==1.4.0 and method 1 works for me. Really great answer!
❓ Questions & Help
Hi, I am working on building a siamese/triplet model. In this case, I need to input pairs with labels (or triplets without labels). In plain PyTorch, an easy way is to have the dataset return two images per sample.
Then I can feed these into a model that accepts two inputs, like

```python
output1, output2 = net(img0, img1)
```

So I was wondering how I could implement a siamese/triplet model in your library, where the dataset outputs 2 or 3 `x` per sample. Or is there an indirect way to do this?
Hope for your reply. Thanks.