choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Implement adding conformers up to a maximum number of edges in training dataloader #46

Closed: ArnNag closed this pull request 5 months ago

ArnNag commented 11 months ago

Description

Based on my previous experience trying to train SAKE, the edgewise features of the intermediate layers usually account for most of a batch's GPU memory usage. This PR implements a dynamically batching dataloader: it computes the number of edges in the next conformer and keeps adding conformers to the batch until a maximum number of edges is reached.
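A minimal sketch of the batching rule described above, written in plain Python so it runs without PyTorch. The names count_edges and MaxEdgeBatchSampler are hypothetical, not the PR's code. The sketch assumes each conformer is given as a list of (x, y, z) coordinates, counts edges as directed atom pairs closer than the cutoff, and walks the dataset in order (no shuffling). Because it yields lists of indices and defines __len__, an object like this follows the protocol that torch.utils.data.DataLoader expects for its batch_sampler argument.

```python
import math
from typing import Iterator, List, Sequence, Tuple

Coord = Tuple[float, float, float]


def count_edges(positions: Sequence[Coord], cutoff: float) -> int:
    """Count directed edges (i, j), i != j, with interatomic distance below cutoff."""
    n = len(positions)
    return sum(
        1
        for i in range(n)
        for j in range(n)
        if i != j and math.dist(positions[i], positions[j]) < cutoff
    )


class MaxEdgeBatchSampler:
    """Hypothetical batch sampler: yields lists of conformer indices whose
    summed edge count stays within max_edges (conformers taken in dataset order)."""

    def __init__(self, conformers: Sequence[Sequence[Coord]], cutoff: float, max_edges: int):
        # Precompute the edge count of every conformer once.
        self.edge_counts = [count_edges(c, cutoff) for c in conformers]
        self.max_edges = max_edges

    def __iter__(self) -> Iterator[List[int]]:
        batch: List[int] = []
        batch_edges = 0
        for idx, n_edges in enumerate(self.edge_counts):
            # Close the current batch if the next conformer would exceed the budget.
            if batch and batch_edges + n_edges > self.max_edges:
                yield batch
                batch, batch_edges = [], 0
            # A conformer that alone exceeds max_edges still forms its own batch.
            batch.append(idx)
            batch_edges += n_edges
        if batch:
            yield batch

    def __len__(self) -> int:
        # Number of batches per epoch.
        return sum(1 for _ in self)
```

Letting an oversized conformer form a batch of its own is one design choice; skipping it or raising an error are equally plausible alternatives.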

codecov-commenter commented 10 months ago

Codecov Report

Merging #46 (5f7f248) into main (c2dbf18) will increase coverage by 0.18%. The diff coverage is 84.21%.

wiederm commented 8 months ago

Is this PR ready for review?

ArnNag commented 8 months ago

My question is how to integrate this into the PyTorch Lightning DataLoader API. For now, I have implemented the maximum-edge dataloader as a separate function here.

There are currently two issues with overriding the existing train_dataloader function:

  1. It breaks current expectations about the shape of batched data. I introduced a custom batch_sampler argument to DataLoader, which batches differently from the batch_size argument: the number of conformers per batch varies. The dataset generation for test cases depends on these expectations.
  2. Batching this way requires two more hyperparameters: the maximum number of edges and the Euclidean cutoff we use to filter the edges. The first issue needs to be resolved before we can decide how to supply these hyperparameters. (One option is to default both to float('inf') when we want to use all edges, but we would have to reimplement downstream tests to work with the shapes this outputs.)
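To illustrate the shape issue in point 1 and the float('inf') default in point 2, here is a hypothetical variant (edge_budget_batches and fixed_size_batches are illustrative names, not from the PR) that also keeps a batch_size cap. Keeping that cap is an assumption of this sketch, not something proposed in the thread. With max_edges left at float('inf'), it yields the same index batches as an unshuffled DataLoader with that batch_size, so existing shape expectations would hold; with a finite max_edges, batches shrink and their sizes vary.

```python
import math
from typing import Iterator, List, Sequence


def edge_budget_batches(
    edge_counts: Sequence[int],
    batch_size: int,
    max_edges: float = math.inf,
) -> Iterator[List[int]]:
    """Hypothetical: yield index batches capped by both batch_size and max_edges."""
    batch: List[int] = []
    batch_edges = 0
    for idx, n_edges in enumerate(edge_counts):
        full = len(batch) == batch_size
        over_budget = bool(batch) and batch_edges + n_edges > max_edges
        if full or over_budget:
            yield batch
            batch, batch_edges = [], 0
        batch.append(idx)
        batch_edges += n_edges
    if batch:
        yield batch


def fixed_size_batches(n: int, batch_size: int) -> List[List[int]]:
    """Index batches an unshuffled DataLoader(batch_size=...) would produce."""
    return [list(range(i, min(i + batch_size, n))) for i in range(0, n, batch_size)]


edge_counts = [3, 1, 4, 1, 5, 9, 2]
# With the default max_edges=inf, batching matches the fixed-size layout.
assert list(edge_budget_batches(edge_counts, batch_size=3)) == fixed_size_batches(7, 3)
# With a finite edge budget, the number of conformers per batch varies.
print(list(edge_budget_batches(edge_counts, batch_size=3, max_edges=6)))
# → [[0, 1], [2, 3], [4], [5], [6]]
```

Being a generator function, this is single-use; to pass it as batch_sampler across several epochs it would need to be wrapped in an object whose __iter__ calls it afresh.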
wiederm commented 8 months ago

Can we hold off on this one for now and concentrate on implementing SAKE?