lab-cosmo / metatrain

Training and evaluating machine learning models for atomistic systems.
https://lab-cosmo.github.io/metatrain/
BSD 3-Clause "New" or "Revised" License

Implement `get_stats` for `Dataset` and print it before training #251

Closed: frostedoyster closed this 3 weeks ago

frostedoyster commented 3 weeks ago

This implements `get_stats()` for `Dataset`. Closes #205.
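To illustrate what a `get_stats()` method on a dataset might look like, here is a minimal sketch. It is not the code in this PR. The `Dataset` class is a hypothetical stand-in in which each sample maps target names (such as an assumed `"energy"`). to floats. The choice of statistics (structure count, plus the per-target mean and population standard deviation) is also an assumption.

```python
import statistics
from typing import Dict, List


class Dataset:
    """Hypothetical stand-in: each sample maps target names to float values."""

    def __init__(self, samples: List[Dict[str, float]]):
        self.samples = samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, float]:
        return self.samples[index]

    def get_stats(self) -> str:
        # Gather every value of each target across all samples
        values: Dict[str, List[float]] = {}
        for sample in self.samples:
            for name, value in sample.items():
                values.setdefault(name, []).append(value)

        # One header line, then one line per target with mean and std
        lines = [f"Dataset containing {len(self)} structures"]
        for name, vals in values.items():
            mean = statistics.mean(vals)
            std = statistics.pstdev(vals)
            lines.append(f"  {name}: mean={mean:.4g}, std={std:.4g}")
        return "\n".join(lines)
```

Calling `Dataset([{"energy": 1.0}, {"energy": 3.0}]).get_stats()` returns a two-line summary that reports 2 structures, with an energy mean of 2 and a standard deviation of 1. This is the kind of string one could print before training.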

📚 Documentation preview 📚: https://metatrain--251.org.readthedocs.build/en/251/

frostedoyster commented 3 weeks ago

TODO: gradients (edit: done)

@PicoCentauri the main issue is that we often don't use our `Dataset` directly, but rather torch `Subset`s, which don't have our `__repr__`. If we want to keep the `__repr__` approach, we will have to write our own `Subset` that inherits from torch's. Otherwise, we can extract the current `__repr__` into a standalone function that takes either a `Dataset` or a `Subset`.
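The first option, subclassing `Subset` so that subsets keep our `__repr__`, could be sketched as follows. This sketch uses assumed stand-ins so that it runs without torch. Here, `Subset` mimics the `.dataset`/`.indices` layout of `torch.utils.data.Subset`, and `StatsSubset` is a hypothetical name. With real torch, one would subclass `torch.utils.data.Subset` directly.

```python
from typing import Any, Dict, List, Sequence


class Dataset:
    """Hypothetical stand-in for our dataset, with the summary in __repr__."""

    def __init__(self, samples: List[Dict[str, float]]):
        self.samples = samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Dict[str, float]:
        return self.samples[index]

    def __repr__(self) -> str:
        return f"Dataset containing {len(self)} structures"


class Subset:
    """Stand-in for torch.utils.data.Subset: no custom __repr__."""

    def __init__(self, dataset: Any, indices: Sequence[int]):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx: int) -> Any:
        return self.dataset[self.indices[idx]]

    def __len__(self) -> int:
        return len(self.indices)


class StatsSubset(Subset):
    """Hypothetical subclass that restores our summary on subsets."""

    def __repr__(self) -> str:
        # Rebuild a Dataset from the selected samples and reuse its __repr__
        return repr(Dataset([self[i] for i in range(len(self))]))
```

A plain `Subset` prints Python's default object repr, while `StatsSubset(ds, [0, 1])` prints the same summary as a `Dataset`. The catch is that every helper that creates subsets must be changed to build `StatsSubset` instead of `Subset`.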

frostedoyster commented 3 weeks ago

It's a bit cumbersome at the moment. Having it as a method forces us to inherit from torch's `Subset`. It also forces us to modify the `train_test_split` function, which would otherwise return one of torch's `Subset`s rather than ours. These complications could be avoided if we made `get_stats` a standalone function.
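The following minimal sketch shows the standalone alternative. It assumes each sample is a mapping from target names to floats, and it uses a stand-in `Subset` with the same indexing behaviour as torch's. The function relies only on `__len__` and `__getitem__`. Because `torch.utils.data.Subset` implements both, the same function would accept a `Subset` without any subclassing. The reported statistics are an assumption and do not reflect what the PR prints.

```python
import statistics
from typing import Any, Dict, List, Sequence


class Subset:
    """Stand-in for torch.utils.data.Subset (index remapping only)."""

    def __init__(self, dataset: Any, indices: Sequence[int]):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx: int) -> Any:
        return self.dataset[self.indices[idx]]

    def __len__(self) -> int:
        return len(self.indices)


def get_stats(dataset: Any) -> str:
    """Summarize any map-style dataset; only __len__/__getitem__ are used."""
    values: Dict[str, List[float]] = {}
    for i in range(len(dataset)):
        for name, value in dataset[i].items():
            values.setdefault(name, []).append(value)

    lines = [f"Dataset containing {len(dataset)} structures"]
    for name, vals in values.items():
        mean = statistics.mean(vals)
        std = statistics.pstdev(vals)
        lines.append(f"  {name}: mean={mean:.4g}, std={std:.4g}")
    return "\n".join(lines)
```

For example, with `samples = [{"energy": 1.0}, {"energy": 2.0}, {"energy": 3.0}]`, both `get_stats(samples)` and `get_stats(Subset(samples, [0, 2]))` work unchanged. No splitting helper needs to know about a custom subclass.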