wjmaddox / swa_gaussian

Code repo for "A Simple Baseline for Bayesian Uncertainty in Deep Learning"
BSD 2-Clause "Simplified" License

Sampling using SWAG #30

Open ayhem18 opened 6 months ago

ayhem18 commented 6 months ago

First of all, thank you for publicly sharing your work.

I am a senior CS bachelor's student, and I am using the SWAG estimate of the posterior P(theta | dataset) as part of a theoretical framework supporting the empirical results I have reached so far in my thesis. I have a couple of questions concerning the SWAG class.

So, if I understand correctly, the SWAG class provides a way to sample from the posterior and to compute the log probability of those samples. After digging deeper into the code, I see that the output of the "compute_logprob" method depends on the values of "mean_list", "var_list", and "covar_mat_root_list" generated by the "generate_mean_var_covar()" method, as indicated in the code snippet below:

[screenshot of compute_logprob(): its output is built from mean_list, var_list, and covar_mat_root_list]
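To make that dependency concrete, here is a minimal diagonal-only sketch of what "compute_logprob" has to do with those lists. This is my own paraphrase, not the repo's code, and it deliberately ignores the low-rank term carried by "covar_mat_root_list":

```python
import torch

def diag_logprob_sketch(sample_list, mean_list, var_list, eps=1e-30):
    """Simplified stand-in for compute_logprob: score a parameter sample
    under independent per-module Gaussians. The low-rank covariance
    contribution from covar_mat_root_list is deliberately omitted."""
    logprob = 0.0
    for w, mean, var in zip(sample_list, mean_list, var_list):
        dist = torch.distributions.Normal(mean, torch.sqrt(var + eps))
        logprob = logprob + dist.log_prob(w).sum()
    return logprob
```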

Going through the "generate_mean_var_covar()" method, I see that these values are extracted from the "mean" and "sq_mean" attributes of each sub-module, as indicated in the code snippet below:

[screenshot of generate_mean_var_covar(): the lists are built from each sub-module's mean and sq_mean]
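The key identity here is that a running first and second moment are enough to recover a variance: Var[w] = E[w^2] - (E[w])^2. A small sketch of that step (the clamp constant is my guess at a numerical guard, not taken from the repo):

```python
import torch

def variance_from_moments(mean: torch.Tensor, sq_mean: torch.Tensor) -> torch.Tensor:
    # Var[w] = E[w^2] - (E[w])^2, clamped so floating-point round-off
    # cannot produce a negative variance.
    return torch.clamp(sq_mean - mean ** 2, min=1e-30)
```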

So, in order to get different outputs from the "compute_logprob()" method, the values of "mean" and "sq_mean" in the different sub-modules need to change. However, the only method that changes these values is "collect_model()". Hence, I conjectured that I should proceed as follows:

  1. Define a base model, then define a SWAG model with the same base class.
  2. When training the base model, call swag_model.collect_model() at the end of each epoch, as this updates the posterior parameters (the running mean and the covariance matrix).
  3. After training, the swag_model can be used to sample from the posterior distribution, as sketched below.
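Concretely, I have something like the following in mind. This is a minimal sketch: BaseNet, loader, criterion, optimizer, and num_epochs are placeholders, and the SWAG constructor arguments are paraphrased from the repo's experiment scripts rather than copied:

```python
from swag.posteriors import SWAG

base_model = BaseNet()  # placeholder base network
swag_model = SWAG(BaseNet, no_cov_mat=False, max_num_models=20)  # same base class

for epoch in range(num_epochs):
    for x, y in loader:  # ordinary training of the base model
        optimizer.zero_grad()
        loss = criterion(base_model(x), y)
        loss.backward()
        optimizer.step()
    # end of epoch: fold the current weights into the running SWAG statistics
    swag_model.collect_model(base_model)
```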

I would greatly appreciate it if you could confirm or correct my understanding of your implementation. Thanks a lot in advance.

wjmaddox commented 6 months ago

Not sure I follow, but yes, that's what I'd expect the code to be doing.

ayhem18 commented 6 months ago

Thank you for the prompt response. To clarify, I would like to build on your work and sample from the posterior distribution P(parameter | data). I proceed as follows:

  1. Create a base model and a SWAG model separately. When training the base model, at the end of each epoch, I call swag_model.collect_model(base_model).

  2. After training, I call swag_model.sample(), which, if I am not mistaken, sets the parameters of the SWAG model to w, where w ~ P(param | data), and then swag_model.compute_logprob(), which returns an estimate of log P(w | data). (I write swag_model rather than base_model here, since the base model itself has no sample() or compute_logprob() method; see the sketch after this list.)
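Put together, step 2 would look roughly like this. The sample() arguments are paraphrased from the repo's experiment scripts, and the exact compute_logprob() signature is best checked against swag/posteriors/swag.py:

```python
swag_model.sample(scale=1.0, cov=True)   # draw w from the approximate posterior into the SWAG weights
log_prob = swag_model.compute_logprob()  # assumption: exact arguments per swag/posteriors/swag.py
```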

Would you please confirm or correct my understanding of the code? Thanks a lot in advance.