sbi-benchmark / sbibm

Simulation-based inference benchmark
https://sbi-benchmark.github.io
MIT License

First noreference posterior task #34

Open · psteinb opened 2 years ago

psteinb commented 2 years ago

As documented in #19, this PR rebases on main and adopts #19 accordingly

Closes #18

psteinb commented 2 years ago

In the course of setting up this PR, I ran into this problem inside snpe:

self = <sbibm.tasks.noref_beam.task.NorefBeam object at 0x7f6b410389a0>
data = tensor([3.1400e+02, 0.0000e+00, 3.3500e+02, 3.0000e+00, 3.5800e+02, 1.0000e+00,
        3.8700e+02, 1.0000e+00, 4.2300...0, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0198e+04])

    def flatten_data(self, data: torch.Tensor) -> torch.Tensor:
        """Flattens data

        Data returned by the simulator is always flattened into 2D Tensors
        """
>       return data.reshape(-1, self.dim_data)
E       RuntimeError: shape '[-1, 400]' is invalid for input of size 401
E         Trace Shapes:                
E          Param Sites:                
E         Sample Sites:                
E       parameters dist           1 | 4
E                 value           1 | 4
E             data dist 200 200   1 |  
E                 value       1 400 |

I have to find out where this comes from.

psteinb commented 2 years ago

This PR is almost done. However, when calling this simulator with SNPE I get an error caused by some internal pyro magic. My pyro-fu is not enough to disentangle what is going on.

This is the backtrace at the position where the error is raised:

(Pdb) bt
  /usr/lib64/python3.9/bdb.py(623)runcall()
-> res = func(*args, **kwds)
  /home/steinbac/development/sbibm/repo/tests/tasks/noref_beam/test_interface_noref_beam.py(106)test_benchmark_metrics_selfobserved_three()
-> outputs, nsim, logprob_truep = run_snpe(
  /home/steinbac/development/sbibm/repo/sbibm/algorithms/sbi/snpe.py(82)run()
-> transforms = task._get_transforms(automatic_transforms_enabled)["parameters"]
  /home/steinbac/development/sbibm/repo/sbibm/tasks/task.py(296)_get_transforms()
-> _, transforms = get_log_prob_fn(
  /home/steinbac/development/sbibm/repo/sbibm/utils/pyro.py(76)get_log_prob_fn()
-> max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
  /home/steinbac/development/sbibm/repo/sbibm/utils/pyro.py(375)_guess_max_plate_nesting()
-> model_trace = poutine.trace(model).get_trace(*args, **kwargs)
  /home/steinbac/development/sbibm/repo/sbibm-main-venv/lib64/python3.9/site-packages/pyro/poutine/trace_messenger.py(198)get_trace()
-> self(*args, **kwargs)
  /home/steinbac/development/sbibm/repo/sbibm-main-venv/lib64/python3.9/site-packages/pyro/poutine/trace_messenger.py(174)__call__()
-> ret = self.fn(*args, **kwargs)
  /home/steinbac/development/sbibm/repo/sbibm-main-venv/lib64/python3.9/site-packages/pyro/poutine/messenger.py(12)_context_wrap()
-> return fn(*args, **kwargs)
  /home/steinbac/development/sbibm/repo/sbibm/tasks/task.py(340)model_fn()
-> return simulator(prior_())
  /home/steinbac/development/sbibm/repo/sbibm/tasks/simulator.py(58)__call__()
-> data = self.simulator(parameters, **kwargs)
> /home/steinbac/development/sbibm/repo/sbibm/tasks/noref_beam/task.py(278)simulator()
-> len(samples.shape) >= 3

For some reason, bdist = pdist.Binomial(total_count=self.flood_samples, probs=img) does not return the expected result in this scenario, and the subsequent pyro.sample call produces something that is missing a dimension. This could be a batch_shape versus event_shape issue, but I don't understand where it comes from.

@jan-matthis if you have any idea, please let me know.

jan-matthis commented 2 years ago

The trouble you run into stems from the call transforms = task._get_transforms(automatic_transforms_enabled)["parameters"]. This method tries to automatically construct transformations into unbounded parameter space by inspecting the pyro model of the task. Since the output of the task is not a sample from a pyro.sample call but rather a reduced version of it (due to the torch.sum statements), this automatic construction cannot work.

I think we should be able to make this work by 1) removing the assert statement you added, and 2) overriding the _get_transforms method in the task class, e.g. by:

    # requires: from typing import Any, Dict; import torch
    def _get_transforms(
        self,
        *args,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # Fall back to identity transforms; IndependentTransform(..., 1) treats
        # the last parameter dimension as the event dimension.
        return {
            "parameters": torch.distributions.transforms.IndependentTransform(
                torch.distributions.transforms.identity_transform, 1
            )
        }

Note that this would simply define identity transforms. I would rather suggest manually specifying transforms that map the prior bounds into unbounded space via a combination of sigmoidal/affine transforms (a sketch follows below). The Gaussian Linear Uniform task has a similar prior (only with different bounds); to see which transforms are automatically constructed for it, you can take a look at them as follows:

import sbibm
task = sbibm.get_task("gaussian_linear_uniform")
transforms = task._get_transforms()

Being able to transform the problem into unbounded space can help performance with NLE/NRE -- so it would be nice to have appropriate transforms as part of the class for methods that do want to exploit them. Having said that, in principle the code should work with the identity transform as specified above.
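For illustration, here is a minimal sketch (not the sbibm implementation) of how such a transform could be composed manually for a 4-dimensional uniform prior with bounds low/high; it mirrors the sigmoid/affine structure that the automatic construction produces for an interval support:

import torch
from torch.distributions import transforms as tt

# placeholder bounds of a 4-dimensional uniform prior
low = torch.tensor([20.0, 20.0, 5.0, 5.0])
high = torch.tensor([80.0, 80.0, 15.0, 15.0])

# unbounded -> bounded: sigmoid into (0, 1), then affine into (low, high)
to_bounded = tt.ComposeTransform(
    [tt.SigmoidTransform(), tt.AffineTransform(loc=low, scale=high - low)]
)
# treat the last dimension as the event dimension and invert: bounded -> unbounded
to_unbounded = tt.IndependentTransform(to_bounded, 1).inv

theta = torch.tensor([[25.0, 70.0, 10.0, 7.5]])
z = to_unbounded(theta)            # unconstrained parameters
theta_back = to_unbounded.inv(z)   # round trip back to the original values
assert torch.allclose(theta, theta_back, atol=1e-5)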

From what I saw, you might encounter a problem with SNPE down the line in tests -- conditioning on multiple observations is currently only supported by NLE/NRE, as explained in the message that sbi raises.

I hope this helps!

psteinb commented 2 years ago

Thanks for the hints. For the record, I saw that only two_moons and slcp override _get_transforms.

Based on your comment, I added a test for SNLE (which is tricky in its own right, as these MCMC-based SNLE runs take time ... too long for a unit test). Since I keep getting out-of-support posteriors, like so

E           ValueError: Expected value argument (Tensor of shape (1, 4)) to be within the support (Interval(lower_bound=tensor([20., 20.,  5.,  5.]), upper_bound=tensor([80., 80., 15., 15.]))) of the distribution Uniform(low: torch.Size([4]), high: torch.Size([4])), but found invalid values:
E           tensor([[80.0028, 27.8328, 14.9613,  6.0025]])

I was wondering whether I could expand the noref_beam task so that it only uses summary statistics of my observations x_o (e.g. the locations of quartiles) instead of the full projections of my multivariate normal.

Would I have to write a new task for this, or can this be achieved within the existing task via the args/kwargs of the task constructor?

jan-matthis commented 2 years ago

Sure!

We should be able to fix the out-of-support problem by running MCMC in unconstrained space. For this, we would need to implement transformations to unconstrained space, e.g. a combination of affine and inverse-sigmoid transforms. Adopting the transforms of the gaussian_linear_uniform task, which you can inspect as per the code snippet in my previous comment, should work for this. A pleasant side effect could also be that SNLE runs faster.

As for the summary-stats question: I think it would be fine to implement this alternative scenario using the same class. I did something similar for the Bernoulli GLM task, which has a variant with raw features as well as one with summary stats (and likewise for SLCP with/without distractors).
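As an aside, a minimal sketch of what such a summary-statistics reduction could look like (entirely hypothetical: the function name, shapes, and the choice of quartile locations as statistics are assumptions, not the actual task code):

import torch

def quartile_locations(projection: torch.Tensor) -> torch.Tensor:
    """Reduce a (batch, n_bins) tensor of projected counts to the bin
    indices where the cumulative mass first reaches each quartile."""
    cdf = torch.cumsum(projection, dim=-1)
    cdf = cdf / cdf[..., -1:].clamp_min(1e-12)  # normalise to [0, 1]
    quartiles = torch.tensor([0.25, 0.5, 0.75], dtype=cdf.dtype)
    targets = quartiles.expand(cdf.shape[:-1] + (3,)).contiguous()
    return torch.searchsorted(cdf.contiguous(), targets).to(projection.dtype)

# e.g. instead of returning the full projections:
# value = torch.cat([quartile_locations(first), quartile_locations(second)], dim=-1)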

psteinb commented 2 years ago

I dug a bit into the issue of

transformations to unconstrained space, e.g. a combination of affine and inverse sigmoidal transform.

It took me quite some time to figure out which part of sbibm generates the transforms you referred to there. I guess you meant the following:

{'parameters': _InverseTransform(IndependentTransform(ComposeTransform(
    SigmoidTransform(),
    AffineTransform()
), 1))}

If I understood correctly, the get_log_prob_fn function in pyro.py generates this transforms dictionary based on what it can infer from the pyro model that each task represents. As my simulator is not quite standard, I guess it cannot infer this correctly. In other words, if the following lines:

# project on the axes
first = torch.sum(samples, axis=-2)  # along y, onto x
second = torch.sum(samples, axis=-1)  # along x, onto y

# concatenate and return
value = torch.cat([first, second], axis=-1)

could be implemented in a pyro-friendly way, I wouldn't have run into this issue. Because gaussian_linear_uniform finishes its simulator with a pyro.sample call, it appears to trigger the transforms mentioned above "correctly".

Bottom line

I removed the offending assert statement in my simulator and replaced it with an if-clause sentinel. It is not nice, but it appears to work. :shrug:

For the time being, I left a commented-out _get_transforms in the code of my task. After some poking around, I could see something like the following flying inside noref_beam/task.py:_get_transforms:

# Warning: this is conceptual, untested code
# requires: from typing import Any, Dict; from torch.distributions import biject_to
def _get_transforms(self, *args, **kwargs: Any) -> Dict[str, Any]:
    prior_dist = self.get_prior_dist()
    # map the prior support into unconstrained space via the inverse bijection
    return {"parameters": biject_to(prior_dist.support).inv}

But at this point I lack the sbibm and pyro skills to judge the consequences of this monkey patch. Thoughts appreciated.

jan-matthis commented 2 years ago

Great!

The custom transforms you wrote are exactly what I was getting at. I think both options are fine -- you could either specify the transforms manually or use biject_to(prior_dist.support).inv. The advantage of the former is that it is immediately obvious what is going on, whereas the advantage of the latter is that it is easier to adapt, e.g., if one wanted to run the task with a different prior.

I'd probably go for the more general solution and add a test for it, checking it against the manual transform and asserting that the two are equivalent.

psteinb commented 2 years ago

I worked on this again. It turns out that the result of biject_to(prior_dist.support).inv cannot be inspected easily, so I took a more data-driven approach in the tests via a round trip. ;-)
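A minimal sketch of what such a round-trip test could look like (the test name is illustrative and it assumes the task is registered as "noref_beam"; this is not the actual test from the PR):

import torch
import sbibm

def test_parameter_transform_roundtrip():
    task = sbibm.get_task("noref_beam")  # assumes the task is registered under this name
    transform = task._get_transforms()["parameters"]

    theta = task.get_prior()(num_samples=100)  # samples within the prior support
    z = transform(theta)                       # constrained -> unconstrained
    theta_back = transform.inv(z)              # and back

    assert torch.allclose(theta, theta_back, atol=1e-5)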

psteinb commented 2 years ago

Alright, running the python ./setup.py build_ext command locally on an fc35 box appears to work. The static linking against MKL fails since I don't have it installed and setup.py has it hard-coded.

jan-matthis commented 2 years ago

Thanks for getting back to this!

Could you try whether tests pass when you install MKL, or alternatively, (temporarily) remove it as a requirement?

psteinb commented 2 years ago

Sorry for getting back to this with some delay. I have removed MKL entirely now. If you still see some of it left in, please let me know.