dirmeier / sbijax

Simulation-based inference in JAX
https://sbijax.rtfd.io
Apache License 2.0

clarification on make_maf function #39

Closed: Jice-Zeng closed this 1 month ago

Jice-Zeng commented 1 month ago

Hi Simon, I looked into the code and have some questions about the design of the surjective flow. There is a function:

```python
def make_maf(
    n_dimension: int,
    n_layers: Optional[int] = 5,
    n_layer_dimensions: Optional[Iterable[int]] = None,
    hidden_sizes: Iterable[int] = (64, 64),
    activation: Callable = jax.nn.tanh,
)
```

You provided an example:

```python
neural_network = make_maf(10, n_layer_dimensions=(10, 10, 5, 5, 5))
```

Does this mean we have a total of 5 layers, with the 1st and 2nd layers being dimensionality-preserving (keeping the dimension at 10) and the 3rd, 4th, and 5th layers being dimensionality-reducing (reducing the dimension by a factor of 2, from 10 to 5)?
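For concreteness, here is a minimal sketch of how I read that example; the import path `sbijax.nn` is an assumption on my part and may differ across sbijax versions:

```python
# Hypothetical sketch of the example above; the exact import path
# is an assumption and may differ in your sbijax version.
from sbijax.nn import make_maf

# Five flow layers with dimensions (10, 10, 5, 5, 5):
# layers 1-2 keep the dimension at 10 (dimensionality-preserving),
# layer 3 reduces 10 -> 5 (dimensionality-reducing),
# layers 4-5 keep the reduced dimension at 5.
neural_network = make_maf(10, n_layer_dimensions=(10, 10, 5, 5, 5))
```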

If my initial dimension is 200 and I want a reduction factor of 0.75, with only the middle layer being dimensionality-reducing, should the network be

```python
neural_network = make_maf(200, n_layer_dimensions=(200, 200, 150, 200, 200))
```

where 150 = 200 × 0.75? I appreciate your assistance!

Best regards,

Jice

Jice-Zeng commented 1 month ago

Correction: it should be

```python
neural_network = make_maf(200, n_layer_dimensions=(200, 200, 150, 150, 150))
```

where 150 = 200 × 0.75. After the dimensionality reduction at the 3rd layer, the remaining layers keep the same dimension (dimensionality-preserving).
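A minimal sketch of the corrected construction, again assuming the `sbijax.nn` import path:

```python
# Hypothetical sketch following the correction above; the import path
# is an assumption and may differ in your sbijax version.
from sbijax.nn import make_maf

reduced = int(200 * 0.75)  # 150

# Layers 1-2 preserve the 200-dimensional input (dimensionality-preserving);
# layer 3 reduces 200 -> 150 (dimensionality-reducing);
# layers 4-5 preserve the reduced 150 dimensions.
neural_network = make_maf(
    200, n_layer_dimensions=(200, 200, reduced, reduced, reduced)
)
```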