Closed: furcelay closed this issue 1 month ago.
We could use a class that handles the prior construction to do this. It could also help convert a list of models into a dictionary and implement solution 1.
This seems to require many bijector operations; we should evaluate its overhead compared with solution 2. Here is an example:
class Model:
    """One profile plus the prior/constant specification of its parameters.

    Every name in ``profile.params`` must appear in ``params``. Each value is
    either a ``tfd.Distribution`` (a free variable with that prior) or
    anything ``float()`` accepts (a fixed constant).

    Attributes:
        profile: the profile passed in; its string form is used in messages.
        variables: dict of parameter name -> distribution.
        constants: dict of parameter name -> float.
        prior: ``tfd.JointDistributionNamed(variables)``, or ``None`` when
            every parameter is constant.
    """

    def __init__(self, profile, params):
        """Split ``params`` into variables and constants.

        Args:
            profile: object whose ``params`` attribute lists the expected
                parameter names.
            params: mapping from parameter name to a distribution or a number.

        Raises:
            RuntimeError: if ``params`` contains an unknown name, lacks one of
                ``profile.params``, or holds a value that is neither a
                distribution nor convertible to ``float``.
        """
        self.profile = profile
        self.variables = {}
        self.constants = {}
        self.prior = None
        # Validate the key sets before looking at any value, so structural
        # errors are reported regardless of which values are malformed.
        for k in params:
            if k not in profile.params:
                raise RuntimeError(f"Unknown parameter '{k}' for model {profile}.")
        for k in profile.params:
            if k not in params:
                raise RuntimeError(f"Missing parameter '{k}' for model {profile}.")
        for k in profile.params:
            p = params[k]
            if isinstance(p, tfd.Distribution):
                self.variables[k] = p
                continue
            try:
                self.constants[k] = float(p)
            except (TypeError, ValueError) as err:
                # float() raises TypeError (not ValueError) for e.g. None or a
                # list; both mean "not a usable constant".
                raise RuntimeError(
                    f"Invalid value {p} for parameter '{k}', should be number or distribution."
                ) from err
        if self.variables:
            self.prior = tfd.JointDistributionNamed(self.variables)

    def __repr__(self):
        return f"{self.profile}(vars:{list(self.variables.keys())},const:{list(self.constants.keys())})"
class CompoundModel:
    """Group several ``Model`` objects under string keys ``'0'``, ``'1'``, ...

    Keys are the models' positions in the input sequence, as strings starting
    at ``'0'``. For example, with two models where only the first has free
    parameters::

        profiles:     {'0': prof0, '1': prof1}
        constants:    {'0': {'p4': 1.0}, '1': {...}}
        prior:        JointDistributionNamed({'0': prior0})
        pack_example: {'0': <sample of prior0>, '1': {}}
    """

    def __init__(self, models=None):
        """Build the per-key views and the joint prior.

        Args:
            models: iterable of model objects, each exposing ``profile``,
                ``constants`` and ``prior`` attributes. ``None`` (the default)
                means no models.
        """
        # Copy so later mutation of the caller's list cannot desync
        # ``self.models`` from the dicts derived below. This also lets any
        # iterable be passed.
        self.models = [] if models is None else list(models)
        self.keys = [str(i) for i in range(len(self.models))]
        keyed = list(zip(self.keys, self.models))
        self.profiles = {k: m.profile for k, m in keyed}
        self.constants = {k: m.constants for k, m in keyed}
        priors = {k: m.prior for k, m in keyed if m.prior is not None}
        self.prior = None
        # One entry per model. Fully-constant models map to an empty dict,
        # a structure with no leaves, whether or not any other model has a
        # prior.
        self.pack_example = {k: {} for k in self.keys}
        if priors:
            self.prior = tfd.JointDistributionNamed(priors)
            # One sample is drawn only to obtain the nested structure.
            example = self.prior.sample()
            self.pack_example = {
                k: example[k] if k in priors else {} for k in self.keys
            }

    def __repr__(self):
        return f"CompoundModel({self.models})"
class LensModel:
    """Lens, source and foreground model groups plus their joint prior.

    Attributes:
        lenses, sources, foreground: the ``CompoundModel`` groups.
        phys_model: ``PhysicalModel`` built from the groups' profiles and
            constants.
        prior: joint prior over every group that has one, transformed so its
            samples use the structure
            ``{'lenses': ..., 'sources': ..., 'foreground': ...}`` with each
            value being that group's ``pack_example`` structure. ``None``
            when no group has a prior.
        pack_bij: bijector whose forward transposes a 2-D event, reshapes it
            to 1-D, splits it into one piece per parameter and packs the
            pieces into the structure above. NOTE(review): this presumably
            expects parameters stacked along the last event axis; confirm.
            ``None`` when no group has a prior.
        unconstraining_bij: the joint prior's default event-space bijector,
            expressed on the restructured parameters. ``None`` when no group
            has a prior.
        bij: ``unconstraining_bij`` applied after ``pack_bij`` (forward
            direction). ``None`` when no group has a prior.
    """

    def __init__(self, lenses=None, sources=None, foreground=None):
        """Assemble the physical model, the joint prior and the bijectors.

        Args:
            lenses, sources, foreground: ``CompoundModel`` groups; ``None``
                (the default) means an empty group.
        """
        # Build fresh empty groups per call. A ``CompoundModel()`` default in
        # the signature would be created once at import time and shared by
        # every LensModel instance.
        lenses = CompoundModel() if lenses is None else lenses
        sources = CompoundModel() if sources is None else sources
        foreground = CompoundModel() if foreground is None else foreground
        self.lenses = lenses
        self.sources = sources
        self.foreground = foreground
        self.phys_model = PhysicalModel(
            lenses.profiles,
            sources.profiles,
            foreground.profiles,
            lenses.constants,
            sources.constants,
            foreground.constants,
        )
        # Always define these, so a fully-constant model does not raise
        # AttributeError later.
        self.prior = None
        self.pack_bij = None
        self.unconstraining_bij = None
        self.bij = None

        groups = {'lenses': lenses, 'sources': sources, 'foreground': foreground}
        priors = {name: g.prior for name, g in groups.items() if g.prior is not None}
        if not priors:
            return

        prior = tfd.JointDistributionNamed(priors)
        example = prior.sample()
        # Same structure as ``example``, but every group key is present
        # (groups without a prior contribute their, possibly empty,
        # pack_example).
        extended_example = {name: g.pack_example for name, g in groups.items()}
        # Forward: flatten a sample of ``prior``, then repack it into
        # ``extended_example``'s structure. Chain applies its last element
        # first.
        extend_bij = tfb.Chain([
            tfb.pack_sequence_as(extended_example),
            tfb.Invert(tfb.pack_sequence_as(example)),
        ])
        self.prior = tfd.TransformedDistribution(prior, extend_bij)

        # Split needs one piece per leaf, so count the leaves directly.
        # tf.size() on the list would stack the leaves and count their
        # elements instead, which differs (or fails) for non-scalar leaves.
        size = len(tf.nest.flatten(example))
        self.pack_bij = tfb.Chain([
            tfb.pack_sequence_as(extended_example),
            tfb.Split(size),
            tfb.Reshape(event_shape_out=(-1,), event_shape_in=(size, -1)),
            tfb.Transpose(perm=(1, 0)),
        ])
        # Conjugate the default event-space bijector so that it works on the
        # extended structure.
        self.unconstraining_bij = tfb.Chain([
            extend_bij,
            prior.experimental_default_event_space_bijector(),
            tfb.Invert(extend_bij),
        ])
        self.bij = tfb.Chain([self.unconstraining_bij, self.pack_bij])

    def __repr__(self):
        return f"lenses: {self.lenses} | sources: {self.sources} | foreground: {self.foreground}"
Solved with the new prior functionality.
Currently, a fully constant model cannot be included, because its prior would be an empty dictionary.
Solutions:
We also need a way to know where to inject this structure; this is simple with dictionaries.