@piEsposito Thanks! Can the KL computation in the `forward` function be skipped when `return_kl` is set to `False`? That would save compute cycles when the KL is not returned.
@ranganathkrishnan we can do that, though I'm afraid that might pollute the code way too much. To avoid that, I'm thinking of the following changes:

- On the `forward` method, add a new argument `do_not_compute_kl` that defaults to `False`.
- Move the `kl` computation so it is assigned to the `self.kl` attribute.
- If `do_not_compute_kl` is set to `True`, set `self.kl` to `None`; otherwise compute the KL divergence.
- Return the `self.kl` attribute if `return_kl` is set to `True`.
Something like:

```python
def forward(self, x, return_kl=True, do_not_compute_kl=False):
    ...
    if do_not_compute_kl:
        self.kl = None
    else:
        self.kl = self.kl_div(self.mu_kernel, sigma_weight,
                              self.prior_weight_mu, self.prior_weight_sigma)
    ...
    # returning outputs + perturbations
    if return_kl:
        return outputs + perturbed_outputs, self.kl
    return outputs + perturbed_outputs
```
What do you think? If you agree with that, I can implement it that way and add it to this PR.
@piEsposito, I think it's better not to introduce additional condition flags. I will merge this PR as it is. Thank you for the contribution.
Closes #7.
Lets the user, if they want, return only the predictions from the `forward` method, while saving the KL divergence as an attribute. This is important to make it easier to integrate into PyTorch models.

Also, it does not break the lib as it is: we added a new parameter on the `forward` method that defaults to `True` and, if manually set to `False`, returns the predictions only.
Performed the following changes on all layers:

- Added `return_kl` to all `forward` methods, defaulting to `True`. If set to `False`, the method won't return `kl`.
- Added a `kl` attribute to each layer, updated at every feedforward step. Useful when integrating with already-built PyTorch models (see the sketch below).

That should help integrate with PyTorch experiments while keeping backward compatibility towards this lib.
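A minimal sketch of how this could look from the user's side, assuming the layers follow the pattern described above. The import path and the `LinearReparameterization` class with its constructor arguments are illustrative assumptions, not taken from this thread:

```python
import torch
import torch.nn as nn
# Hypothetical import path; the actual layer/module names may differ.
from bayesian_torch.layers import LinearReparameterization

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative constructor signature.
        self.fc = LinearReparameterization(in_features=4, out_features=2)

    def forward(self, x):
        # return_kl=False: the layer returns predictions only, so it drops
        # into an ordinary nn.Module forward pass without unpacking a tuple.
        return self.fc(x, return_kl=False)

net = Net()
out = net(torch.randn(8, 4))
# The KL divergence is still available as an attribute, updated on the
# last feedforward step, e.g. for adding it to the loss.
kl = net.fc.kl
```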