MaddyThakker closed this issue 3 years ago.
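The reason for this is that `predicted` is a Tensor of shape `torch.Size([10])`, but `classes` is a Tensor of shape `torch.Size([10, 1])`. When the two are compared elementwise (e.g. `predicted == classes`), broadcasting produces a `[10, 10]` result, so the number of "correct" predictions can exceed the batch size. `classes` needs to be reshaped to the same shape as `predicted` for the accuracy to be computed correctly.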
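A minimal sketch of the shape mismatch, assuming accuracy is computed as the number of matching elements divided by the batch size. The tensor values are made up for illustration and this is not the code from the traja test or the linked PR.

```python
import torch

# Illustrative batch of 10 predictions (values are hypothetical).
predicted = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 0])  # shape [10]
# Labels identical to the predictions, but with an extra trailing dim.
classes = predicted.clone().unsqueeze(1)                 # shape [10, 1]

# Buggy comparison: [10] vs [10, 1] broadcasts to a [10, 10] matrix,
# so every pair of equal values is counted, not just matching positions.
buggy_correct = (predicted == classes).sum().item()
buggy_accuracy = buggy_correct / classes.size(0)

# Fix: reshape classes to the shape of predicted before comparing.
fixed_correct = (predicted == classes.view_as(predicted)).sum().item()
fixed_accuracy = fixed_correct / classes.size(0)

print(buggy_accuracy, fixed_accuracy)  # 5.0 1.0
```

With five 0s and five 1s, the broadcast comparison counts 5 × 5 + 5 × 5 = 50 matches, giving an "accuracy" of 5.0. For a `[10, 1]` tensor, `classes.squeeze(1)` or `classes.view(-1)` would reshape it equally well.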
Resolved with https://github.com/traja-team/traja/pull/66
File `test_models.py` implements `test_ae_classification_network_converges` (line 334), which during validation, `trainer.validate(data_loaders['train_loader'])` (line 406), reports an accuracy greater than 1.