IntelLabs / bayesian-torch

A library for Bayesian neural network layers and uncertainty estimation in Deep Learning extending the core of PyTorch
BSD 3-Clause "New" or "Revised" License

Let the models return prediction only, saving KL Divergence as an attribute #9

Closed: piEsposito closed this 2 years ago

piEsposito commented 2 years ago

Closes #7.

Lets users optionally return only the predictions from the forward method, while the KL divergence is saved as an attribute. This makes the layers easier to integrate into PyTorch models.

This does not break the library as it stands: we added a new parameter, return_kl, to the forward method. It defaults to True (the current behavior, returning predictions and KL); if it is set to False, forward returns predictions only.
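A minimal sketch of the pattern, written in plain Python so it runs without PyTorch. `ToyBayesianLinear` and `gaussian_kl` are hypothetical stand-ins, not bayesian-torch classes or functions: the layer samples its weights with the reparameterization trick, always stores the KL divergence to a Gaussian prior in `self.kl`, and returns `(outputs, kl)` by default or the outputs alone when `return_kl=False`.

```python
import math
import random


def gaussian_kl(mu_q, sigma_q, mu_p, sigma_p):
    """KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for scalar Gaussians."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)


class ToyBayesianLinear:
    """Hypothetical mean-field layer: y_j = sum_i w_ji * x_i, w_ji ~ N(mu_ji, sigma_ji^2)."""

    def __init__(self, in_features, out_features, prior_mu=0.0, prior_sigma=1.0):
        self.mu = [[random.gauss(0.0, 0.1) for _ in range(in_features)]
                   for _ in range(out_features)]
        self.sigma = [[0.1] * in_features for _ in range(out_features)]
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.kl = None  # set by every forward pass

    def forward(self, x, return_kl=True):
        outputs = []
        kl = 0.0
        for mu_row, sigma_row in zip(self.mu, self.sigma):
            # reparameterization: w = mu + sigma * eps, eps ~ N(0, 1)
            w = [m + s * random.gauss(0.0, 1.0) for m, s in zip(mu_row, sigma_row)]
            outputs.append(sum(wi * xi for wi, xi in zip(w, x)))
            kl += sum(gaussian_kl(m, s, self.prior_mu, self.prior_sigma)
                      for m, s in zip(mu_row, sigma_row))
        self.kl = kl  # always saved as an attribute
        if return_kl:
            return outputs, self.kl  # previous behavior (default)
        return outputs  # predictions only


layer = ToyBayesianLinear(in_features=3, out_features=2)
out, kl = layer.forward([1.0, 2.0, 3.0])                 # default: same as before
preds = layer.forward([1.0, 2.0, 3.0], return_kl=False)  # predictions only
kl_after = layer.kl                                      # KL still available
```

Because `return_kl` defaults to True, existing callers that unpack `out, kl = layer.forward(x)` keep working unchanged.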

Performed the following changes on all layers:

This should make integration with PyTorch experiments easier while keeping this library backward compatible.
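To illustrate the integration benefit, here is a hypothetical sketch (plain Python, no PyTorch; `ToyLayer`, `run_model` and `elbo_loss` are illustrative names, not bayesian-torch APIs). Layers are chained with `return_kl=False` so each passes a bare prediction to the next, and the KL terms are read back from their `kl` attributes afterwards. The squared-error data term stands in for a negative log-likelihood, and dividing the KL by the number of minibatches is one common scaling choice, assumed here.

```python
import math
import random


class ToyLayer:
    """Hypothetical layer: scales its input by one weight w ~ N(mu, sigma^2), prior N(0, 1)."""

    def __init__(self, mu, sigma):
        self.mu, self.sigma = mu, sigma
        self.kl = None

    def forward(self, x, return_kl=True):
        w = self.mu + self.sigma * random.gauss(0.0, 1.0)
        # KL(N(mu, sigma^2) || N(0, 1)), saved on the layer
        self.kl = -math.log(self.sigma) + (self.sigma ** 2 + self.mu ** 2) / 2 - 0.5
        out = [w * xi for xi in x]
        return (out, self.kl) if return_kl else out


def run_model(layers, x):
    """Chain layers like a plain sequential model, passing predictions only."""
    for layer in layers:
        x = layer.forward(x, return_kl=False)
    return x


def elbo_loss(layers, preds, targets, num_batches):
    """Squared-error data term plus the KL terms collected from each layer."""
    data_term = sum((p - t) ** 2 for p, t in zip(preds, targets))
    kl_term = sum(layer.kl for layer in layers)
    return data_term + kl_term / num_batches


model = [ToyLayer(0.5, 0.2), ToyLayer(1.0, 0.1)]
preds = run_model(model, [1.0, 2.0])
loss = elbo_loss(model, preds, [0.5, 1.0], num_batches=10)
```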

ranganathkrishnan commented 2 years ago

@piEsposito Thanks! Can the KL computation in the forward function be skipped when return_kl is set to False? That would save compute cycles when the KL is not returned.

piEsposito commented 2 years ago

@ranganathkrishnan we can do that, though I'm afraid it might clutter the code too much. To avoid that, I'm thinking of the following changes:

  1. Add a new argument, do_not_compute_kl, to every forward method, defaulting to False.
  2. Assign the KL computation to the self.kl attribute.
  3. If do_not_compute_kl is True, set self.kl to None; otherwise, compute the KL divergence.
  4. Return the self.kl attribute if return_kl is True.

Something like:

def forward(self, x, return_kl=True, do_not_compute_kl=False):
    ...  # elided: forward-pass code that computes sigma_weight, among others
    if do_not_compute_kl:
        # skip the KL computation entirely
        self.kl = None
    else:
        self.kl = self.kl_div(self.mu_kernel, sigma_weight,
                              self.prior_weight_mu, self.prior_weight_sigma)
    ...  # elided: rest of the forward pass
    # returning outputs + perturbations
    if return_kl:
        return outputs + perturbed_outputs, self.kl
    return outputs + perturbed_outputs
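A runnable toy version of the proposed flag, in plain Python. `ToyBayesianLayer` and its `kl_calls` counter are hypothetical, added only to make the skipped work observable; the layer has one weight with a N(0, 1) prior. This sketches the design proposed in this comment, not code from this PR.

```python
import math
import random


class ToyBayesianLayer:
    """Hypothetical layer with one weight w ~ N(mu, sigma^2) and a N(0, 1) prior."""

    def __init__(self, mu=0.0, sigma=0.5):
        self.mu, self.sigma = mu, sigma
        self.kl = None
        self.kl_calls = 0  # counts how often the KL is actually computed

    def kl_div(self):
        """KL(N(mu, sigma^2) || N(0, 1))."""
        self.kl_calls += 1
        return -math.log(self.sigma) + (self.sigma ** 2 + self.mu ** 2) / 2 - 0.5

    def forward(self, x, return_kl=True, do_not_compute_kl=False):
        w = self.mu + self.sigma * random.gauss(0.0, 1.0)
        outputs = [w * xi for xi in x]
        if do_not_compute_kl:
            self.kl = None  # skip the KL work entirely
        else:
            self.kl = self.kl_div()
        if return_kl:
            return outputs, self.kl
        return outputs


layer = ToyBayesianLayer()
preds = layer.forward([1.0, 2.0], return_kl=False, do_not_compute_kl=True)
```

Note that with the two flags combined, `return_kl=True, do_not_compute_kl=True` returns `(outputs, None)`, so callers would have to handle a missing KL term.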

What do you think? If you agree, I can implement it that way and add it to this PR.

ranganathkrishnan commented 2 years ago

@piEsposito, I think it's better not to introduce additional condition flags. I will merge this PR as it is. Thank you for the contribution.