sansiro77 opened this issue 3 years ago.
Hello @sansiro77, and thank you for all your contributions to BLiTZ.
I don't think I understood your issue; could you give me some more details? What is the expected behavior, and how is it behaving instead? As soon as I know this, I can try to fix it if there is a bug.
Sorry for the late reply.
The default data type in PyTorch is `torch.float32`.
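For readers less familiar with the dtype handling involved, here is a small sketch using only core PyTorch (`layer`, `same`, `x` and `out` are illustrative names). It shows the float32 default, how `Module.double()` casts a module in place, and why the input data has to be converted as well.

```python
import torch
import torch.nn as nn

# Tensors and module parameters are created as float32 unless told otherwise.
print(torch.get_default_dtype())          # torch.float32

layer = nn.Linear(3, 2)
print(layer.weight.dtype)                 # torch.float32

# Module.double() casts floating-point parameters and buffers to float64
# in place and returns the same module object, so it can be chained.
same = layer.double()
print(same is layer, layer.weight.dtype)  # True torch.float64

# The inputs must be converted too; feeding float32 data into a float64
# layer raises a dtype-mismatch error.
x = torch.randn(4, 3)                     # float32 by default
out = layer(x.double())
print(out.dtype)                          # torch.float64
```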
Take `bayesian_LeNet_mnist.py` as an example. If you print the model parameters during training, you get something like:
```
Iteration: 250 | Accuracy of the network on the 10000 test images: 93.3 %
fc3 weight mu rho [0,0] 0.07923417538404465 -6.868440628051758
Iteration: 500 | Accuracy of the network on the 10000 test images: 95.77 %
fc3 weight mu rho [0,0] 0.07503953576087952 -6.837374687194824
Iteration: 750 | Accuracy of the network on the 10000 test images: 96.22 %
fc3 weight mu rho [0,0] 0.06658764183521271 -6.785527229309082
Iteration: 1000 | Accuracy of the network on the 10000 test images: 97.2 %
fc3 weight mu rho [0,0] 0.06792515516281128 -6.734956741333008
```
So far so good.
However, when I modified the code as follows:
```python
# Excerpt: only the modified lines are shown.
classifier = BayesianCNN().to(device).double()
loss = classifier.sample_elbo(inputs=datapoints.to(device).double(),
outputs = classifier(images.to(device).double())
```
The output became:
```
Iteration: 250 | Accuracy of the network on the 10000 test images: 93.57 %
fc3 weight mu rho [0,0] 0.05312253162264824 -6.983799934387207
Iteration: 500 | Accuracy of the network on the 10000 test images: 96.62 %
fc3 weight mu rho [0,0] 0.05312253162264824 -6.983799934387207
Iteration: 750 | Accuracy of the network on the 10000 test images: 96.66 %
fc3 weight mu rho [0,0] 0.05312253162264824 -6.983799934387207
Iteration: 1000 | Accuracy of the network on the 10000 test images: 97.07 %
fc3 weight mu rho [0,0] 0.05312253162264824 -6.983799934387207
```
In fact, none of the printed parameters change at all.
The model is still being updated (the test accuracy keeps improving), but I don't think we can read out the correct parameters.
The same problem also appeared when I ran on a GPU, even without `.double()`.
I found that if you use `.double()` to change the type of both the model and the data, the printed parameters stay unchanged (although the model seems to be updated). Is there an explanation for this?
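A minimal sketch of one mechanism that would produce exactly this symptom. It is an assumption, not a description of BLiTZ's actual code: suppose a layer registers `weight_mu`/`weight_rho` and also re-wraps them as new `nn.Parameter` objects inside a sampler submodule that `forward` reads from. On a float32 CPU model the two objects share storage, so the optimizer's in-place updates to the sampler's copy also show up in `weight_mu`. `.double()` (or moving to a GPU) converts each registered parameter separately into new memory, after which only the sampler's copy is trained and `weight_mu` stays frozen. `Sampler`, `ToyBayesLinear`, `train` and `run` are hypothetical names.

```python
import torch
import torch.nn as nn


class Sampler(nn.Module):
    """Hypothetical stand-in for a weight sampler that re-wraps mu/rho."""

    def __init__(self, mu, rho):
        super().__init__()
        # nn.Parameter(existing_param) creates a NEW Parameter object
        # that shares storage with the original one.
        self.mu = nn.Parameter(mu)
        self.rho = nn.Parameter(rho)


class ToyBayesLinear(nn.Module):
    """Hypothetical toy layer; not BLiTZ's BayesianLinear."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.weight_mu = nn.Parameter(torch.empty(n_out, n_in).uniform_(-0.1, 0.1))
        self.weight_rho = nn.Parameter(torch.full((n_out, n_in), -7.0))
        self.weight_sampler = Sampler(self.weight_mu, self.weight_rho)

    def forward(self, x):
        # Only the sampler's copies are used, so only they receive gradients.
        sigma = torch.log1p(torch.exp(self.weight_sampler.rho))
        w = self.weight_sampler.mu + sigma * torch.randn_like(sigma)
        return x @ w.t()


def train(model, dtype, steps=20):
    x = torch.randn(64, 4, dtype=dtype)
    y = torch.randn(64, 2, dtype=dtype)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        opt.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        opt.step()  # in-place update of the parameters that have gradients


def run(convert_to_double):
    torch.manual_seed(0)
    model = ToyBayesLinear(4, 2)
    dtype = torch.float32
    if convert_to_double:
        model = model.double()  # converts each registered Parameter separately
        dtype = torch.float64
    shared = model.weight_mu.data_ptr() == model.weight_sampler.mu.data_ptr()
    mu_before = model.weight_mu.detach().clone()
    sampler_before = model.weight_sampler.mu.detach().clone()
    train(model, dtype)
    mu_changed = not torch.equal(mu_before, model.weight_mu.detach())
    sampler_changed = not torch.equal(sampler_before, model.weight_sampler.mu.detach())
    return shared, mu_changed, sampler_changed


if __name__ == "__main__":
    print("float32:", run(False))
    print("float64:", run(True))
```

Under these assumptions, `run(False)` returns `(True, True, True)` and `run(True)` returns `(False, False, True)`: after the conversion the two tensors no longer share memory, `weight_mu` never moves, and only the sampler's copy is trained. If BLiTZ's layers follow a similar pattern, the values worth printing after a conversion would be the ones held by the submodule that `forward` actually reads.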