borchero / pycave

Traditional Machine Learning Models for Large-Scale Datasets in PyTorch.
https://pycave.borchero.com
MIT License

Enable compatibility with torchmetrics >= 0.6.0 #27

Closed: marcovarrone closed this 1 year ago

marcovarrone commented 1 year ago

First of all, amazing project!

  1. I replaced the torchmetrics class AverageMeter with MeanMetric, which behaves essentially the same.
  2. I added the class attribute full_state_update = False to each of PyCave's custom metrics, as suggested here. I set all of them to False because it does not seem to influence the test results.
  3. Within KmeansPlusPlusInitLightningModule.nonparametric_training_epoch_end I was getting errors because the choice variable is sometimes an empty tensor, and indexing a zero-dimensional tensor is not allowed since pytorch>=0.5. Is it supposed to be empty? Is something breaking that the tests do not capture? I added a simple if check there, but let me know if you have a more elegant solution.

I ran the tests with both pytorch-lightning=1.6.0, torchmetrics=0.6.0 (the minimum required versions) and pytorch-lightning=1.8.1, torchmetrics=0.10.2 (the latest versions at the time of writing), and everything seems to work well.

Have a nice day!