sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Parameter inputs may not conform into a single tensor #218

Open alvorithm opened 4 years ago

alvorithm commented 4 years ago

Background

Our current notion of a simulator's parameters is either something structured (such as a cross-section matrix) or a collection of scalars and vectors that can be concatenated into one long vector, with a prior made up of independent priors over different coordinates (@janfb?).

Goal

It should be possible to pass in differently shaped parameters, e.g. a positive rate, a covariance matrix, and a speed vector, each with its own factor prior (i.e. without having to build an Independent prior that formally becomes a prior over all of them). This is more expressive and clearer.

Caveats

janfb commented 4 years ago

Actually, I think we cannot handle multi-dimensional parameters yet, e.g., no matrices. If a user wants to define a prior over a matrix, she has to implement a prior wrapper that returns a flattened version and then change her simulator so that it takes this flattened vector and reshapes it into a matrix for simulation.
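A minimal sketch of that workaround in plain PyTorch, assuming a 2x2 matrix parameter. The names `flat_prior`, `matrix_simulator`, and `flat_simulator` are hypothetical and not part of sbi; the Gaussian prior and the toy simulator are stand-ins chosen only to make the shapes concrete.

```python
import torch
from torch.distributions import Independent, Normal

D = 2  # side length of the (hypothetical) matrix parameter

# Prior over the D*D flattened matrix entries (illustrative choice).
flat_prior = Independent(Normal(torch.zeros(D * D), torch.ones(D * D)), 1)


def matrix_simulator(matrix: torch.Tensor) -> torch.Tensor:
    """Toy simulator that expects a (D, D) matrix and returns its row sums."""
    return matrix @ torch.ones(D)


def flat_simulator(theta: torch.Tensor) -> torch.Tensor:
    """Adapter: reshape each flat parameter vector into a matrix, then simulate."""
    matrices = theta.reshape(-1, D, D)
    return torch.stack([matrix_simulator(m) for m in matrices])


theta = flat_prior.sample((5,))  # shape (5, D * D)
x = flat_simulator(theta)        # shape (5, D)
```

The inference code then only ever sees the flat `theta`; the matrix structure lives entirely inside the adapter.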

janfb commented 6 months ago

Update: we do not require an explicit prior object anymore, at least not for NPE. One can just pass pre-simulated theta and x.

theta must still be a tensor, though. So if you have differently shaped "sub-priors", you need to take care of flattening and concatenating their samples into a one-dimensional tensor (one flat vector per parameter set).
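One way to do that bookkeeping by hand, sketched in plain PyTorch under the assumption that the parameters are the rate, covariance matrix, and speed vector from the original post. The specific distributions, the sizes, and the helper `unpack` are illustrative assumptions, not sbi functionality.

```python
import torch
from torch.distributions import Gamma, Normal

N = 100  # number of pre-simulated parameter sets

# Samples from three differently shaped "sub-priors" (illustrative choices).
rate = Gamma(torch.tensor(2.0), torch.tensor(1.0)).sample((N,))  # (N,)
a = Normal(0.0, 1.0).sample((N, 2, 2))
cov = a @ a.transpose(-1, -2) + torch.eye(2)  # (N, 2, 2), positive definite
speed = Normal(torch.zeros(3), torch.ones(3)).sample((N,))  # (N, 3)

# Flatten each part to (N, k) and remember the widths for unpacking later.
parts = [rate.reshape(N, -1), cov.reshape(N, -1), speed.reshape(N, -1)]
sizes = [p.shape[1] for p in parts]  # [1, 4, 3]
theta = torch.cat(parts, dim=1)      # (N, 8): one flat vector per parameter set


def unpack(theta: torch.Tensor):
    """Invert the packing: split a flat theta back into rate, cov, speed."""
    r, c, s = torch.split(theta, sizes, dim=-1)
    return r.squeeze(-1), c.reshape(*theta.shape[:-1], 2, 2), s
```

A simulator can call `unpack` on each flat vector to recover the structured parameters. Note that a density estimator trained on the flat `theta` knows nothing about constraints such as symmetry or positive-definiteness of `cov`.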

janfb commented 1 month ago

Using pytrees, as in JAX, as parameter objects might be an option for adding this feature.

How Pytrees Could Help

  1. Handling Complex Data Structures: Pytrees can easily accommodate complex and nested structures. For your scenario, where parameters might include a positive rate (scalar), a covariance matrix (2D tensor), and a speed vector (1D tensor), Pytrees allow you to organize these into a single nested structure. You can then define operations that will apply uniformly across this structure.
  2. Independent Priors: Each component of the Pytree can have its associated prior distribution. This allows you to manage and manipulate these priors independently, while still treating them as a coherent unit.
  3. Mapping Functions: JAX’s utilities for Pytrees allow you to map functions over the entire structure. This is useful for tasks like computing likelihoods, gradients, or transforming parameters, which can be done in a consistent and automated way without having to manually unpack and repack the structures.
  4. Neural Network Construction: When designing neural networks that process these structured inputs, Pytrees can help manage the complexity. You can define layers or transformations that operate on parts of the Pytree and then combine the results, respecting the structured nature of the inputs.
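The first three points can be sketched with JAX's pytree utilities. This is only an illustration of the idea, not a proposal for sbi's API: the dictionary layout, the parameter values, and the per-component log-priors are assumptions, and the "prior" on `cov` is a placeholder that ignores symmetry and positive-definiteness.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.stats import gamma, norm

# Hypothetical structured parameter: a dict is a pytree with three leaves.
params = {
    "rate": jnp.asarray(0.5),             # scalar
    "cov": jnp.eye(2),                    # 2D array
    "speed": jnp.array([1.0, 2.0, 3.0]),  # 1D array
}

# Point 3: map a function over every leaf without unpacking by hand.
doubled = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, params)

# Point 2: one log-prior per component, stored in a pytree of the same structure.
log_priors = {
    "rate": lambda r: gamma.logpdf(r, a=2.0),
    "cov": lambda c: norm.logpdf(c).sum(),  # placeholder, not a proper matrix prior
    "speed": lambda s: norm.logpdf(s).sum(),
}
per_leaf = jax.tree_util.tree_map(lambda f, p: f(p), log_priors, params)
log_prior = sum(jax.tree_util.tree_leaves(per_leaf))

# Point 1: flatten the whole structure into one vector and back again.
flat, unravel = ravel_pytree(params)  # flat has shape (8,)
restored = unravel(flat)              # same structure and shapes as params
```

`ravel_pytree` would give a network the single flat vector it needs, while `unravel` restores the structured view for the simulator and the per-component priors.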