wjmaddox / online_vargp

Online variational GPs
GNU General Public License v3.0

Cannot reproduce `active_learning` experiment #2

Open st-- opened 2 years ago

st-- commented 2 years ago

When running

```
python qnIPV_experiment.py --num_init=10 --model=svgp --num_steps=250 --seed=1 --output=malaria_nipv_svgp_1.pt
```

as given in the README.md, I get the following exception:

```
Traceback (most recent call last):
  File "qnIPV_experiment.py", line 301, in <module>
    main(args)
  File "qnIPV_experiment.py", line 228, in main
    candidates, acq_value = optimize_acqf(
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/optim/optimize.py", line 150, in optimize_acqf
    batch_initial_conditions = ic_gen(
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/optim/initializers.py", line 112, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/utils/transforms.py", line 255, in decorated
    return method(cls, X, **kwargs)
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/utils/transforms.py", line 214, in decorated
    output = method(acqf, X, *args, **kwargs)
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/acquisition/active_learning.py", line 92, in forward
    fantasy_model = self.model.fantasize(
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/models/model.py", line 140, in fantasize
    return self.condition_on_observations(X=X, Y=Y_fantasized, **kwargs)
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/models/gpytorch.py", line 196, in condition_on_observations
    return self.get_fantasy_model(inputs=X, targets=Y, **kwargs)
  File ".../online_vargp/volatilitygp/models/single_task_variational_gp.py", line 513, in get_fantasy_model
    return super().get_fantasy_model(
  File ".../online_vargp/volatilitygp/models/single_task_variational_gp.py", line 361, in get_fantasy_model
    fantasy_model = inducing_exact_model.condition_on_observations(
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/models/gpytorch.py", line 394, in condition_on_observations
    self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
  File ".../anaconda3/envs/onlinet/lib/python3.8/site-packages/botorch/models/gpytorch.py", line 94, in _validate_tensor_args
    raise BotorchTensorDimensionError(
botorch.exceptions.errors.BotorchTensorDimensionError: An explicit output dimension is required for observation noise. Expected Yvar with shape: torch.Size([6, 1]) (got torch.Size([5, 6])).
```

Calling the experiment script with `--random`, or running `--model=exact` (with or without `--random`), runs fine.

Might this have something to do with the svgp model getting `init_y.view(-1)`, whereas the other models get `init_y.view(-1, 1)` (and similarly for `init_y_var`)?
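To make the suspected mismatch concrete: `view(-1)` produces a flat `(n,)` tensor, while `view(-1, 1)` produces an `(n, 1)` column with an explicit output dimension, which is the layout BoTorch's error message asks for. The sketch below uses plain PyTorch; `check_noise_shape` is a hypothetical helper written in the spirit of that error message, not BoTorch's actual `_validate_tensor_args`, and the tensors are stand-ins for the script's `init_y` and `init_y_var`. It does not reproduce the exact `[5, 6]` shape from the traceback, only the convention being violated.

```python
import torch

n = 6
init_y = torch.randn(n)              # stand-in for the script's init_y
init_y_var = torch.full((n,), 0.1)   # stand-in for the script's init_y_var

flat = init_y.view(-1)       # shape (6,): no explicit output dimension
column = init_y.view(-1, 1)  # shape (6, 1): n observations, 1 output
print(flat.shape, column.shape)


def check_noise_shape(Y: torch.Tensor, Yvar: torch.Tensor) -> None:
    """Hypothetical check: noise must share Y's explicit output dimension."""
    if Yvar.shape != Y.shape:
        raise ValueError(
            "An explicit output dimension is required for observation noise. "
            f"Expected Yvar with shape: {Y.shape} (got {Yvar.shape})."
        )


check_noise_shape(column, init_y_var.view(-1, 1))  # shapes agree: no error
try:
    check_noise_shape(column, init_y_var.view(-1))  # (6,) vs (6, 1)
except ValueError as err:
    print(err)
```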

wjmaddox commented 2 years ago

Ugh, yes, that looks like some sort of dimension-based error. I'll try to take a closer look this weekend.

Does it work if you enforce the dimension?

st-- commented 2 years ago

Thanks for responding! Could you clarify what you mean by "enforcing the dimension" (what & where)?

st-- commented 2 years ago

(Just to clarify, I used GPyTorch==1.4.0 and botorch==0.4.0, as the code doesn't run with recent versions due to a bunch of breaking changes in those libraries.)

wjmaddox commented 2 years ago

Ugh, yes, not surprising that it no longer works, especially as the core contribution here (fantasization) has been merged into GPyTorch for Gaussian likelihoods, but not into BoTorch (one day I'll publish that PR).
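"Fantasization" here means conditioning a model on hypothetical ("fantasy") observations at candidate points, which is what the `fantasize` -> `condition_on_observations` -> `get_fantasy_model` chain in the traceback is doing. For an exact GP with a Gaussian likelihood this amounts to appending the fantasy points to the training data and recomputing the posterior. The sketch below shows only that exact-GP case in plain PyTorch, assuming a zero prior mean, a fixed RBF kernel and fixed noise; `rbf`, `posterior` and `fantasize` are hypothetical helpers, not GPyTorch or BoTorch API, and this is not the repository's online variational method. Note that the concatenation only works when targets are `(n, 1)` columns.

```python
import torch


def rbf(a: torch.Tensor, b: torch.Tensor, lengthscale: float = 0.5) -> torch.Tensor:
    # Squared-exponential kernel between rows of a (n x d) and b (m x d)
    return torch.exp(-0.5 * torch.cdist(a, b).pow(2) / lengthscale**2)


def posterior(train_x, train_y, test_x, noise=1e-2):
    """Exact GP posterior mean and latent variance at test_x (zero prior mean)."""
    K = rbf(train_x, train_x) + noise * torch.eye(train_x.shape[0], dtype=train_x.dtype)
    Ks = rbf(test_x, train_x)                 # (m, n)
    L = torch.linalg.cholesky(K)
    alpha = torch.cholesky_solve(train_y, L)  # K^{-1} y, with train_y of shape (n, 1)
    mean = Ks @ alpha
    v = torch.cholesky_solve(Ks.T, L)         # K^{-1} Ks^T, shape (n, m)
    var = rbf(test_x, test_x).diagonal() - (Ks * v.T).sum(-1)
    return mean, var


def fantasize(train_x, train_y, new_x, fantasy_y):
    """Condition on fantasy observations by appending them to the data."""
    return torch.cat([train_x, new_x]), torch.cat([train_y, fantasy_y])


torch.manual_seed(0)
train_x = torch.rand(6, 1, dtype=torch.float64)
train_y = torch.sin(6 * train_x)  # (6, 1): explicit output dimension
new_x = torch.tensor([[0.5]], dtype=torch.float64)

_, var_before = posterior(train_x, train_y, new_x)
fx, fy = fantasize(train_x, train_y, new_x, torch.zeros(1, 1, dtype=torch.float64))
_, var_after = posterior(fx, fy, new_x)
print(var_before.item(), var_after.item())  # variance at new_x shrinks
```

Conditioning on an observation at `new_x` always reduces the posterior variance there, which is the quantity integrated-variance acquisition functions such as the one in this experiment care about.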

> enforcing the dimension

Off the top of my head, I mean this: "svgp model getting `init_y.view(-1)` whereas the other models get `init_y.view(-1, 1)`".
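One way to read "enforcing the dimension" is to normalize targets (and noise variances) to an `(n, 1)` column before handing them to the model. A minimal sketch, assuming single-output data; `ensure_output_dim` is a hypothetical helper, not part of this repository or BoTorch, and it illustrates the shape convention rather than the fix that was eventually pushed.

```python
import torch


def ensure_output_dim(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return single-output targets as an (n, 1) column."""
    if y.dim() == 1:
        return y.unsqueeze(-1)  # (n,) -> (n, 1)
    if y.dim() == 2 and y.shape[-1] == 1:
        return y  # already (n, 1)
    # A row layout such as (1, n) is rejected rather than silently transposed
    raise ValueError(f"expected shape (n,) or (n, 1), got {tuple(y.shape)}")


init_y = torch.randn(6)
print(ensure_output_dim(init_y).shape)            # column with trailing dim 1
print(ensure_output_dim(init_y.view(-1, 1)).shape)
```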

st-- commented 2 years ago

Well, changing the svgp constructor arguments from `.view(-1)` to either `.view(-1, 1)` or `.view(1, -1)` still doesn't work... Would you be able to fix the script?

wjmaddox commented 2 years ago

I pushed a fix and it ran locally for me. Sorry for the long delay there.