NSAPH-Projects / topological-equivariant-networks

E(n)-Equivariant Topological Neural Networks
MIT License

Dataset loading/preprocessing to be fixed #23

Closed: gdasoulas closed this issue 4 months ago

gdasoulas commented 6 months ago

The CombinatorialComplexTransform function should be passed as the pre_transform argument of QM9. With the current implementation, we need to preprocess the data every time:

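    from torch_geometric.datasets import QM9
    from tqdm import tqdm

    dataset = QM9(root=data_root)
    dataset = dataset.shuffle()
    dataset = dataset[: args.num_samples]
    # The transform runs eagerly here, so it is repeated on every launch
    dataset = [transform(sample) for sample in tqdm(dataset)]

ekarais commented 5 months ago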

I recently tried implementing this, but it turns out to be very difficult for the following reason: if we make CombinatorialComplexTransform the pre_transform, PyG's QM9 class tries to collate all samples after applying it. However, it uses PyG's default collate function, with no option to pass our own, and that function is incompatible with the transformed data. I also tried subclassing PyG's QM9 class to override the collate function, but that did not work either.
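We do not know exactly which attribute trips up PyG's default collate here, but one typical requirement of a custom collate for higher-order cells can be sketched in plain Python. Assuming (hypothetically) that each transformed sample stores its cells as tuples of atom indices keyed by rank, index-valued attributes cannot simply be concatenated: each sample's indices must be shifted by the number of atoms in the samples before it in the batch. The function collate_cells and the sample layout are illustrative, not the repository's code or PyG's API.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical sample layout: {"num_nodes": int, "cells": {rank: [tuple of atom indices, ...]}}
Sample = Dict[str, Any]


def collate_cells(samples: List[Sample]) -> Dict[str, Any]:
    """Merge samples into one batch, offsetting atom indices per sample."""
    batch_cells: Dict[int, List[Tuple[int, ...]]] = {}
    batch_vector: List[int] = []  # batch_vector[i] is the sample id of atom i
    offset = 0
    for sample_id, sample in enumerate(samples):
        num_nodes = int(sample["num_nodes"])
        for rank, cells in sample["cells"].items():
            # Shift every atom index by the atoms already placed in the batch
            shifted = [tuple(i + offset for i in cell) for cell in cells]
            batch_cells.setdefault(rank, []).extend(shifted)
        batch_vector.extend([sample_id] * num_nodes)
        offset += num_nodes
    return {"num_nodes": offset, "cells": batch_cells, "batch": batch_vector}


a = {"num_nodes": 3, "cells": {1: [(0, 1), (1, 2)], 2: [(0, 1, 2)]}}
b = {"num_nodes": 2, "cells": {1: [(0, 1)]}}
print(collate_cells([a, b])["cells"][1])  # [(0, 1), (1, 2), (3, 4)]
```

The edge (0, 1) of the second molecule becomes (3, 4) because three atoms from the first molecule precede it.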

However, I am also bothered by having to preprocess the data every time. It forces us to wait 15 minutes to see the effects of changes we make during debugging. If anyone wants to try solving this, I am happy to help. @gdasoulas @mauriciogtec @clabat9
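One way to avoid paying the preprocessing cost on every launch, sketched under assumptions: compute the transformed samples once, pickle them to a cache file, and load that file on later runs. The helper load_or_build and the cache path are hypothetical, the transform argument stands in for CombinatorialComplexTransform, and the transformed samples are assumed to be picklable. The cache must be invalidated by hand (for example by putting the transform parameters in the file name) whenever the transform changes.

```python
import os
import pickle
from typing import Any, Callable, Iterable, List


def load_or_build(cache_path: str,
                  samples: Iterable[Any],
                  transform: Callable[[Any], Any]) -> List[Any]:
    """Return transformed samples, reading them from cache_path if it exists."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    transformed = [transform(s) for s in samples]
    # Write to a temporary file first so an interrupted run leaves no half-written cache
    tmp_path = cache_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(transformed, f)
    os.replace(tmp_path, cache_path)
    return transformed
```

This still keeps every transformed sample in memory at once, so it only addresses the waiting time, not peak memory.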

ekarais commented 5 months ago

From #22

As mentioned in https://github.com/NSAPH-Projects/topological-equivariant-networks/issues/6, training with Rips-Vietoris uses too much memory (>32 GB), especially when setting dis to large values. Because PyG's QM9 is an InMemoryDataset, setting dis to a large value essentially means holding all 3-combinations of atoms of each molecule in memory, which makes memory usage explode.
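To make the growth concrete, here is a small brute-force sketch (not the repository's transform) that enumerates the 2-simplices of a Vietoris-Rips complex under the convention that a triple of atoms is kept when all three pairwise distances are at most dis. The function rips_triangles and the toy coordinates are illustrative. Once dis exceeds the molecule's diameter every triple qualifies, so a molecule with n atoms yields comb(n, 3) triangles, e.g. 3654 for the 29-atom upper bound of QM9.

```python
import itertools
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def rips_triangles(coords: List[Point], dis: float) -> List[Tuple[int, int, int]]:
    """Return index triples whose three pairwise Euclidean distances are all <= dis."""
    def close(i: int, j: int) -> bool:
        return math.dist(coords[i], coords[j]) <= dis

    # Brute force over all O(n^3) triples of atoms
    return [
        (i, j, k)
        for i, j, k in itertools.combinations(range(len(coords)), 3)
        if close(i, j) and close(i, k) and close(j, k)
    ]


# Toy "molecule": 29 points spaced 1.0 apart on a line (illustrative, not QM9 data)
coords = [(float(x), 0.0, 0.0) for x in range(29)]
print(len(rips_triangles(coords, dis=2.5)))                      # 27
print(len(rips_triangles(coords, dis=100.0)), math.comb(29, 3))  # 3654 3654
```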

The best solution I see is to implement the QM9 dataset ourselves such that it isn't an InMemoryDataset. The transformed dataset would be stored on disk and the dataloader would access each batch from disk.
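A rough sketch of the proposed on-disk layout, assuming each transformed sample is picklable: write every sample to its own file once, then have a map-style dataset read a sample only when it is indexed. The class DiskDataset and the file naming scheme are hypothetical. Because it exposes __len__ and __getitem__, an object like this can be passed to a PyTorch DataLoader together with a custom collate_fn, so only the current batch has to be in memory. (PyG also offers a non-in-memory torch_geometric.data.Dataset base class built around the same idea.)

```python
import os
import pickle
from typing import Any, Iterable


class DiskDataset:
    """Map-style dataset that keeps one pickled sample per file on disk."""

    def __init__(self, root: str) -> None:
        self.root = root
        self.num_samples = len([n for n in os.listdir(root) if n.endswith(".pkl")])

    @staticmethod
    def build(root: str, samples: Iterable[Any]) -> "DiskDataset":
        """Write each (already transformed) sample to root/<index>.pkl."""
        os.makedirs(root, exist_ok=True)
        for idx, sample in enumerate(samples):
            with open(os.path.join(root, f"{idx}.pkl"), "wb") as f:
                pickle.dump(sample, f)
        return DiskDataset(root)

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Any:
        if not 0 <= idx < self.num_samples:
            raise IndexError(idx)
        # Only this one sample is read into memory
        with open(os.path.join(self.root, f"{idx}.pkl"), "rb") as f:
            return pickle.load(f)
```

Passing a generator such as (transform(s) for s in raw_samples) to build means the expensive transform also runs one molecule at a time instead of materializing the whole transformed dataset.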

It is sadly not an option to apply the Rips-Vietoris transform on the fly (before batching). Because it is an expensive algorithm that is applied to each molecule separately, transforming on the fly would increase the runtime of each epoch from 1.5 minutes to 2 hours, meaning it would take over a month to complete 1000 epochs.
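The runtime estimate can be checked with quick arithmetic; the per-epoch times are the ones reported in this comment, not new measurements:

```python
EPOCHS = 1000
PRECOMPUTED_MIN_PER_EPOCH = 1.5  # transform applied once, ahead of training
ON_THE_FLY_H_PER_EPOCH = 2.0     # transform applied to every molecule, every epoch

precomputed_hours = EPOCHS * PRECOMPUTED_MIN_PER_EPOCH / 60
on_the_fly_days = EPOCHS * ON_THE_FLY_H_PER_EPOCH / 24

print(f"precomputed: {precomputed_hours:.1f} h")  # precomputed: 25.0 h
print(f"on the fly: {on_the_fly_days:.1f} days")  # on the fly: 83.3 days
```

So "over a month" is conservative: 1000 epochs on the fly would take close to three months.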

clabat9 commented 5 months ago

I would not spend more than an hour or so on this problem. We can always put the EMPSN table in the results, noting that their procedure would cause OOM. What is the distance threshold they use?

ekarais commented 5 months ago

I think they use dis=4. The bigger pain, though, is having to wait 15+ minutes before each training run, which makes debugging take much longer.

clabat9 commented 5 months ago

Agreed, let's not worry about it; it is really a minor issue.

ekarais commented 4 months ago

Addressed in #45.