yaringal / ConcreteDropout

Code for Concrete Dropout as presented in https://arxiv.org/abs/1705.07832
MIT License
pytorch version: small error? #14

Open · hansweytjens opened this issue 3 years ago

hansweytjens commented 3 years ago

Dear Yarin,

Fascinating research, which I am now trying to use in my own work. I believe there is a small error in the `fit_model` function of the PyTorch version concerning how the batches are computed. It affects execution speed, even if it is probably benign for the results.

I believe this code corrects the error:

```python
...
for i in range(self.nb_epoch):
    for batch in range(int(np.ceil(self.X.shape[0] / self.batch_size))):
        # consecutive, non-overlapping slice for this mini-batch
        _x = self.X[self.batch_size * batch : self.batch_size * (batch + 1)]
        _y = self.Y[self.batch_size * batch : self.batch_size * (batch + 1)]

        x = torch.FloatTensor(_x).cuda()  # 32-bit floating point
        y = torch.FloatTensor(_y).cuda()

        mean, log_var, regularization = self.model(x)  # forward pass
        loss = heteroscedastic_loss(y, mean, log_var) + regularization

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
...
```
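The slice bounds used in the loop above can be checked in isolation. The sketch below is a hypothetical, stdlib-only helper (`batch_slices` is not part of the repository) that reproduces the indexing `batch_size * batch : batch_size * (batch + 1)` over `ceil(N / batch_size)` batches. It is meant only to illustrate that, with this indexing, the mini-batches are consecutive and non-overlapping, so each sample is visited exactly once per epoch.

```python
import math


def batch_slices(n_samples, batch_size):
    """Return (start, stop) pairs for consecutive, non-overlapping mini-batches.

    Hypothetical helper mirroring the indexing
    X[batch_size * batch : batch_size * (batch + 1)]
    for batch in range(ceil(n_samples / batch_size)).
    """
    n_batches = math.ceil(n_samples / batch_size)
    # min() only makes the last (possibly partial) batch explicit;
    # Python slicing would clamp an out-of-range stop index anyway.
    return [
        (batch_size * batch, min(batch_size * (batch + 1), n_samples))
        for batch in range(n_batches)
    ]


# Toy data: 10 "rows" with a batch size of 4 (values chosen only for illustration).
X = list(range(10))
slices = batch_slices(len(X), batch_size=4)
batches = [X[start:stop] for start, stop in slices]

print(slices)   # [(0, 4), (4, 8), (8, 10)]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# Each row appears exactly once per epoch, so an epoch touches N rows in total.
assert sum(len(b) for b in batches) == len(X)
assert [x for b in batches for x in b] == X
```

Because the batches partition the data, the work per epoch scales with the number of samples rather than growing with the batch index.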

Kind regards,

Hans