furcelay / gigalens

Gradient Informed, GPU Accelerated Lens modelling (GIGALens) -- a package for fast Bayesian inference on strong gravitational lenses.
https://giga-lens.github.io/gigalens
MIT License
0 stars 0 forks source link

Better treatment of constant parameters #3

Closed furcelay closed 1 month ago

furcelay commented 6 months ago

Currently, a fully constant model cannot be included, because its prior would be an empty dictionary.

Solutions:

  1. Use bijectors to inject empty structures into the prior samples.
  2. Add the empty structures during simulation.

We also need a way to know where to inject this structure; this is simple with dictionaries.

furcelay commented 6 months ago

Maybe we could use a class that handles the prior construction to do this. It could also help to convert a list of models into a dictionary and to implement solution 1.

This seems to need many bijector operations, so its overhead should be evaluated against solution 2. Here is an example:

class Model:
    """A single profile together with its sampled and fixed parameters.

    Each value in ``params`` is either a ``tfd.Distribution`` (the parameter
    is sampled) or something ``float()`` accepts (the parameter is fixed).

    Attributes:
        profile: the profile object passed in; ``profile.params`` lists the
            parameter names this model expects.
        variables: parameter name -> distribution, for sampled parameters.
        constants: parameter name -> float, for fixed parameters.
        prior: ``tfd.JointDistributionNamed`` over ``variables``, or ``None``
            when every parameter is constant.

    Raises:
        RuntimeError: if ``params`` contains a name the profile does not know,
            is missing one of ``profile.params``, or holds a value that is
            neither a distribution nor convertible to ``float``.
    """

    def __init__(self, profile, params):
        self.profile = profile
        self.variables = {}
        self.constants = {}
        self.prior = None
        # Reject names the profile does not declare before inspecting values.
        for k in params:
            if k not in profile.params:
                raise RuntimeError(f"Unknown parameter '{k}' for model {profile}.")
        for k in profile.params:
            try:
                p = params[k]
            except KeyError:
                raise RuntimeError(f"Missing parameter '{k}' for model {profile}.") from None
            if isinstance(p, tfd.Distribution):
                self.variables[k] = p
                continue
            try:
                self.constants[k] = float(p)
            except (TypeError, ValueError) as err:
                # float() raises TypeError (not ValueError) for e.g. None,
                # lists or dicts; both mean "not a usable constant".
                raise RuntimeError(
                    f"Invalid value {p} for parameter '{k}', should be number or distribution."
                ) from err
        # A fully constant model has no prior; ``prior`` stays None.
        if self.variables:
            self.prior = tfd.JointDistributionNamed(self.variables)

    def __repr__(self):
        return f"{self.profile}(vars:{list(self.variables.keys())},const:{list(self.constants.keys())})"

class CompoundModel:
    """Group several ``Model`` objects under string keys "0", "1", ...

    Keys are each model's index in ``models`` converted to ``str``::

        profiles:      {"0": prof0,         "1": prof1,    ...}
        prior sample:  {"0": {p1, p2, p3},  "1": {p1, p2}, ...}  (models with a prior only)
        constants:     {"0": {p4},          "1": {},       ...}

    ``pack_example`` always has one entry per model: a prior sample for
    models with a prior and an empty dict for fully constant models, so the
    packed structure has a slot for every model even when none of them has
    a prior.
    """

    def __init__(self, models=None):
        if models is None:
            models = []
        self.models = models
        self.keys = [str(i) for i in range(len(models))]
        self.profiles = {str(i): m.profile for i, m in enumerate(models)}
        self.constants = {str(i): m.constants for i, m in enumerate(models)}

        priors = {str(i): m.prior for i, m in enumerate(models) if m.prior is not None}
        self.prior = None
        # Empty placeholder slots for fully constant models. Previously these
        # were only added when at least one model had a prior, leaving a bare
        # {} (no slots) for an all-constant CompoundModel.
        self.pack_example = {str(i): {} for i, m in enumerate(models) if m.prior is None}
        if priors:
            self.prior = tfd.JointDistributionNamed(priors)
            # The sample is only used for its nested structure.
            example = self.prior.sample()
            self.pack_example = example | self.pack_example

    def __repr__(self):
        return f"CompoundModel({self.models})"

class LensModel:
    """Combine lens, source and foreground ``CompoundModel`` objects.

    Args:
        lenses, sources, foreground: ``CompoundModel`` instances; ``None``
            (the default) means a fresh, empty ``CompoundModel()``.

    Attributes:
        phys_model: ``PhysicalModel`` built from the three components'
            profiles and constants.
        prior: ``tfd.TransformedDistribution`` whose samples follow the
            "extended" structure built from each component's
            ``pack_example``, or ``None`` when no component has a prior.
        pack_bij, unconstraining_bij, bij: bijectors built alongside
            ``prior``; ``None`` when ``prior`` is ``None``.
    """

    def __init__(self, lenses=None, sources=None, foreground=None):
        # Build defaults per call: a ``CompoundModel()`` default argument is
        # evaluated once at definition time and shared by every LensModel.
        if lenses is None:
            lenses = CompoundModel()
        if sources is None:
            sources = CompoundModel()
        if foreground is None:
            foreground = CompoundModel()
        self.lenses = lenses
        self.sources = sources
        self.foreground = foreground

        self.phys_model = PhysicalModel(
            lenses.profiles,
            sources.profiles,
            foreground.profiles,
            lenses.constants,
            sources.constants,
            foreground.constants,
        )

        components = {'lenses': lenses, 'sources': sources, 'foreground': foreground}
        # Only components with at least one sampled parameter enter the joint prior.
        priors = {name: c.prior for name, c in components.items() if c.prior is not None}

        self.prior = None
        self.pack_bij = None
        self.unconstraining_bij = None
        self.bij = None
        if not priors:
            return

        prior = tfd.JointDistributionNamed(priors)
        # The sample is only used for its nested structure and element count.
        example = prior.sample()
        extended_example = {name: c.pack_example for name, c in components.items()}
        # tfb.Chain applies its bijectors right-to-left: flatten a compact
        # prior sample, then re-pack the leaves into the extended structure.
        extend_bij = tfb.Chain([
            tfb.pack_sequence_as(extended_example),
            tfb.Invert(tfb.pack_sequence_as(example)),
        ])
        self.prior = tfd.TransformedDistribution(prior, extend_bij)
        # Total number of elements across all leaves of one compact sample.
        size = int(tf.size(tf.nest.flatten(example)))
        # NOTE(review): presumably maps a batch of flat parameter vectors to
        # the extended structure; confirm the expected input shape.
        self.pack_bij = tfb.Chain([
            tfb.pack_sequence_as(extended_example),
            tfb.Split(size),
            tfb.Reshape(event_shape_out=(-1,), event_shape_in=(size, -1)),
            tfb.Transpose(perm=(1, 0)),
        ])
        # Extended -> compact, apply the prior's default event-space
        # bijector, then compact -> extended again.
        self.unconstraining_bij = tfb.Chain([
            extend_bij,
            prior.experimental_default_event_space_bijector(),
            tfb.Invert(extend_bij),
        ])
        self.bij = tfb.Chain([self.unconstraining_bij, self.pack_bij])

    def __repr__(self):
        return f"lenses: {self.lenses} | sources: {self.sources} | foreground: {self.foreground}"
furcelay commented 1 month ago

Solved with the new prior functionality.