izmailovpavel / understandingbdl


HMC example not runnable #16

Open · luningsun opened this issue 2 years ago

luningsun commented 2 years ago

Hi, guys.

I am very interested in the HMC example you show in the paper. However, when I run it I get the error "'Scope' object is not iterable", which appears to come from PyTorch. I tried a different version of PyTorch, but the error persists. Below is the detailed error message. I wonder if you have come across a similar issue and whether you could help me solve it.


TypeError                                 Traceback (most recent call last)
 in ()
     41     num_samples=n_samples_per_chain, step_size=step_size,
     42     num_steps_per_sample=num_steps_per_sample, tau_out=tau_out,
---> 43     tau_list=tau_list
     44 )
     45 params_hmc = params_hmc[::num_steps_per_sample]

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/hamiltorch/samplers.py in sample_model(model, x, y, params_init, model_loss, num_samples, num_steps_per_sample, step_size, burn, inv_mass, jitter, normalizing_const, softabs_const, explicit_binding_const, fixed_point_threshold, fixed_point_max_iterations, jitter_max_tries, sampler, integrator, metric, debug, tau_out, tau_list, store_on_GPU, desired_accept_rate)
   1327     torch.cuda.empty_cache()
   1328 
-> 1329     return sample(log_prob_func, params_init, num_samples=num_samples, num_steps_per_sample=num_steps_per_sample, step_size=step_size, burn=burn, jitter=jitter, inv_mass=inv_mass, normalizing_const=normalizing_const, softabs_const=softabs_const, explicit_binding_const=explicit_binding_const, fixed_point_threshold=fixed_point_threshold, fixed_point_max_iterations=fixed_point_max_iterations, jitter_max_tries=jitter_max_tries, sampler=sampler, integrator=integrator, metric=metric, debug=debug, desired_accept_rate=desired_accept_rate, store_on_GPU = store_on_GPU)
   1330 
   1331 def sample_split_model(model, train_loader, params_init, num_splits, model_loss='multi_class_linear_output', num_samples=10, num_steps_per_sample=10, step_size=0.1, burn=0, inv_mass=None, jitter=None, normalizing_const=1., softabs_const=None, explicit_binding_const=100, fixed_point_threshold=1e-5, fixed_point_max_iterations=1000, jitter_max_tries=10, sampler=Sampler.HMC, integrator=Integrator.SPLITTING, metric=Metric.HESSIAN, debug=False, tau_out=1.,tau_list=None, store_on_GPU = True, desired_accept_rate=0.8):

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/hamiltorch/samplers.py in sample(log_prob_func, params_init, num_samples, num_steps_per_sample, step_size, burn, jitter, inv_mass, normalizing_const, softabs_const, explicit_binding_const, fixed_point_threshold, fixed_point_max_iterations, jitter_max_tries, sampler, integrator, metric, debug, desired_accept_rate, store_on_GPU)
    941         momentum = gibbs(params, sampler=sampler, log_prob_func=log_prob_func, jitter=jitter, normalizing_const=normalizing_const, softabs_const=softabs_const, metric=metric, mass=mass)
    942 
--> 943         ham = hamiltonian(params, momentum, log_prob_func, jitter=jitter, softabs_const=softabs_const, explicit_binding_const=explicit_binding_const, normalizing_const=normalizing_const, sampler=sampler, integrator=integrator, metric=metric, inv_mass=inv_mass)
    944 
    945         leapfrog_params, leapfrog_momenta = leapfrog(params, momentum, log_prob_func, sampler=sampler, integrator=integrator, steps=num_steps_per_sample, step_size=step_size, inv_mass=inv_mass, jitter=jitter, jitter_max_tries=jitter_max_tries, fixed_point_threshold=fixed_point_threshold, fixed_point_max_iterations=fixed_point_max_iterations, normalizing_const=normalizing_const, softabs_const=softabs_const, explicit_binding_const=explicit_binding_const, metric=metric, store_on_GPU = store_on_GPU, debug=debug)

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/hamiltorch/samplers.py in hamiltonian(params, momentum, log_prob_func, jitter, normalizing_const, softabs_const, explicit_binding_const, inv_mass, ham_func, sampler, integrator, metric)
    759     if sampler == Sampler.HMC:
    760         if type(log_prob_func) is not list:
--> 761             log_prob = log_prob_func(params)
    762 
    763     if util.has_nan_or_inf(log_prob):

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/hamiltorch/samplers.py in log_prob_func(params)
   1137 
   1138 
-> 1139         output = fmodel(x_device, params=params_unflattened)
   1140 
   1141         if model_loss is 'binary_class_linear_output':

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/hamiltorch/util.py in fmodule(*args, **kwargs)
    340     def fmodule(*args, **kwargs):
    341         params_box[0] = kwargs.pop('params') # if key is in the dictionary, remove it and return its value, else return default. If default is not given and key is not in the dictionary, a KeyError is raised.
--> 342         return fmodule_internal(*args, **kwargs)
    343 
    344     return fmodule

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/hamiltorch/util.py in fmodule(*args, **kwargs)
    329         # When running the kwargs no longer exist as they were put into params_box and therefore forward is just
    330         # forward(self, x), so I could comment **kwargs out
--> 331         return forward(self, *args) #, **kwargs)
    332 
    333     return child_params_offset, fmodule

/home/luningsun/anaconda3/envs/python36/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
    138         for module in self:
    139             input = module(input)
--> 140         return input
    141 
    142 

TypeError: 'Scope' object is not iterable
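
For reference, the call that triggers the error looks roughly like this (reconstructed from the traceback above; model, x_train, y_train, and params_init are placeholder names, and the other variables are defined earlier in the notebook):

import hamiltorch

# Sketch of the failing call, based only on what is visible in the traceback.
params_hmc = hamiltorch.sample_model(
    model, x_train, y_train, params_init=params_init,
    num_samples=n_samples_per_chain, step_size=step_size,
    num_steps_per_sample=num_steps_per_sample,
    tau_out=tau_out, tau_list=tau_list,
)
params_hmc = params_hmc[::num_steps_per_sample]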
izmailovpavel commented 2 years ago

Hi @luningsun, sorry for the late reply!

I looked into it, and I believe something changed in hamiltorch, because I get the same error. I managed to get it to work by changing the network definition to a simpler, more explicit version:

import torch
import torch.nn as nn


class RegNet(nn.Sequential):
    def __init__(self):
        super(RegNet, self).__init__()
        self.l1 = torch.nn.Linear(2, 10)
        self.l2 = torch.nn.Linear(10, 10)
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.Linear(10, 10)
        self.l5 = torch.nn.Linear(10, 1)

    def forward(self, x):
        # Explicit forward pass; avoids relying on nn.Sequential's
        # container forward, which iterates over the module itself.
        x = self.l1(x)
        x = torch.nn.functional.relu(x)
        x = self.l2(x)
        x = torch.nn.functional.relu(x)
        x = self.l3(x)
        x = torch.nn.functional.relu(x)
        x = self.l4(x)
        x = torch.nn.functional.relu(x)
        x = self.l5(x)
        return x
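
As far as I can tell, the error comes from nn.Sequential's container forward (the for module in self loop in container.py): hamiltorch rewrites the model into a functional form and calls the original forward with its own Scope object in place of self, and that object is not iterable. Writing an explicit forward, as above, bypasses the container forward entirely. As a rough sketch of how the fixed model plugs back into the sampler (the data tensors and hyperparameter values below are placeholders, not the notebook's exact settings):

import hamiltorch
import torch

net = RegNet()
# hamiltorch operates on a flat vector of all network parameters.
params_init = hamiltorch.util.flatten(net).clone()

# Placeholder data with the 2-d input / 1-d output the network expects.
x_train = torch.randn(64, 2)
y_train = torch.randn(64, 1)

# One prior precision per parameter tensor (placeholder value of 1.0).
tau_list = torch.ones(len(list(net.parameters())))

params_hmc = hamiltorch.sample_model(
    net, x_train, y_train, params_init=params_init,
    num_samples=100, step_size=1e-3, num_steps_per_sample=30,
    tau_out=1.0, tau_list=tau_list,
)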

It may be worth starting an issue at the hamiltorch github (https://github.com/AdamCobb/hamiltorch).

If you don't mind using JAX, I also have an implementation of HMC for a similar problem here: https://github.com/google-research/google-research/blob/master/bnn_hmc/notebooks/synthetic_regression_inference.ipynb.