traja-team / traja

Python tools for spatial trajectory and time-series data analysis
https://traja.readthedocs.io
MIT License

test_ae_classification_network_converges gives accuracy > 1 #57

Closed: MaddyThakker closed this issue 3 years ago

MaddyThakker commented 3 years ago

The file test_models.py implements test_ae_classification_network_converges (line 334). During validation, the call trainer.validate(data_loaders['train_loader']) (line 406) reports an accuracy greater than 1.

MaddyThakker commented 3 years ago

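The cause is a shape mismatch: predicted is a Tensor of shape torch.Size([10]), but classes is a Tensor of shape torch.Size([10, 1]). Comparing the two (e.g. predicted == classes) broadcasts them to a [10, 10] result, so matches are counted across every pair of samples rather than element by element, and the count of "correct" predictions can exceed the batch size. Reshaping classes to the same shape as predicted makes the comparison element-wise and fixes the accuracy.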
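A minimal sketch of the failure and the fix, assuming PyTorch and a hypothetical accuracy helper that counts matches and divides by the batch size. This is not traja's actual validation code; it only reproduces the shapes described above (predicted of shape [10], classes of shape [10, 1]) with made-up labels.

```python
import torch

# Hypothetical batch of 10 samples, mirroring the shapes in this issue.
predicted = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])  # shape [10]
classes = predicted.clone().unsqueeze(1)                   # shape [10, 1], every label correct


def accuracy(predicted: torch.Tensor, classes: torch.Tensor) -> float:
    """Hypothetical helper: number of matches divided by the batch size."""
    correct = (predicted == classes).sum().item()
    return correct / classes.size(0)


# Buggy: [10] == [10, 1] broadcasts to a [10, 10] boolean matrix,
# so matches are counted across every (label, prediction) pair.
buggy = accuracy(predicted, classes)

# Fixed: reshape classes to the shape of predicted so the comparison is element-wise.
fixed = accuracy(predicted, classes.reshape(predicted.shape))

print(buggy, fixed)  # prints "5.0 1.0"
```

With five 0s and five 1s in the batch, the broadcast comparison finds 5 × 5 + 5 × 5 = 50 matches, giving an "accuracy" of 5.0 even though every prediction is correct; after reshaping, the result is 1.0.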

MaddyThakker commented 3 years ago

Resolved with https://github.com/traja-team/traja/pull/66