OxIonics / ionics_fits

Small Python fitting library with an emphasis on Atomic, Molecular and Optical Physics
Apache License 2.0

reparametrise_model #138

Closed: hartytp closed this issue 8 months ago

hartytp commented 9 months ago

While ionics_fits does a pretty good job of providing convenient, flexible model parametrisations, there will always be times when one wants something a little different. For example, sometimes one wants a sinusoid that's parametrised in terms of its minimum/maximum values instead of its amplitude and offset.
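To make the sinusoid example concrete, assume the amplitude/offset form y = y0 + a·sin(ωx + φ). This is an illustrative form, not necessarily the exact definition of the ionics_fits sinusoid model. Its extrema are y0 ± a, so for a ≥ 0 the two parametrisations are related by a = (y_max − y_min)/2 and y0 = (y_max + y_min)/2. A quick numerical check that both describe the same curve (the function names are hypothetical):

```python
import numpy as np


def sinusoid_amp_offset(x, a, omega, phi, y0):
    """Illustrative sinusoid parametrised by amplitude `a` and offset `y0`."""
    return y0 + a * np.sin(omega * x + phi)


def sinusoid_min_max(x, y_min, y_max, omega, phi):
    """The same curve parametrised by its minimum and maximum values."""
    a = 0.5 * (y_max - y_min)  # half the peak-to-peak range
    y0 = 0.5 * (y_max + y_min)  # midpoint between the extrema
    return sinusoid_amp_offset(x, a, omega, phi, y0)


x = np.linspace(0.0, 2.0 * np.pi, 201)
y_ref = sinusoid_amp_offset(x, a=0.3, omega=1.0, phi=0.2, y0=0.5)
y_new = sinusoid_min_max(x, y_min=0.2, y_max=0.8, omega=1.0, phi=0.2)
print(np.allclose(y_ref, y_new))
```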

It would be nice to have a helper function which allows one to reparametrise models. The kind of interface I have in mind is that one passes in a dictionary of new parameters and a list of the parameters they replace, as well as a pair of functions which specify how the values / uncertainties for the replaced parameters are calculated.
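As a rough illustration of what that pair of functions might look like for the min/max sinusoid: the function names and the parameter names `a`, `y0`, `y_min` and `y_max` are hypothetical. The value function maps the new parameters onto the replaced ones. The uncertainty function here takes both the new values and their uncertainties, which is an assumption (the `reparametrise_deriv_func` in the sketch that follows takes a single dict), and it assumes `y_min` and `y_max` are uncorrelated.

```python
import math
from typing import Dict


def min_max_to_amp_offset(new_values: Dict[str, float]) -> Dict[str, float]:
    """Compute the replaced parameters (a, y0) from the new ones (y_min, y_max)."""
    y_min, y_max = new_values["y_min"], new_values["y_max"]
    return {"a": 0.5 * (y_max - y_min), "y0": 0.5 * (y_max + y_min)}


def min_max_to_amp_offset_err(
    new_values: Dict[str, float], new_uncertainties: Dict[str, float]
) -> Dict[str, float]:
    """Propagate uncertainties from (y_min, y_max) to (a, y0).

    Assumes y_min and y_max are uncorrelated. Both a and y0 are linear in
    (y_min, y_max) with coefficients of magnitude 1/2, so each has standard
    error sqrt(sigma_min**2 + sigma_max**2) / 2. `new_values` is unused for
    this linear mapping but would be needed for a nonlinear one.
    """
    sigma = 0.5 * math.hypot(new_uncertainties["y_min"], new_uncertainties["y_max"])
    return {"a": sigma, "y0": sigma}


values = {"y_min": 0.2, "y_max": 0.8}
errors = {"y_min": 0.03, "y_max": 0.04}
print(min_max_to_amp_offset(values))
print(min_max_to_amp_offset_err(values, errors))
```

Ignoring correlations keeps the example short; fitted min/max values are usually correlated, so a real implementation would likely want the covariance matrix rather than independent uncertainties.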

hartytp commented 9 months ago

Started sketching something out...


```python
from typing import Callable, Dict, List, Type

import numpy as np

# Import locations assumed from the ionics_fits package layout.
from ionics_fits.common import Model, ModelParameter
from ionics_fits.utils import Array


def reparametrise_model(
    model: Type[Model],
    new_parameters: Dict[str, ModelParameter],
    replaced_parameters: List[str],
    reparametrise_func: Callable[[Dict[str, float]], Dict[str, float]],
    reparametrise_deriv_func: Callable[[Dict[str, float]], Dict[str, float]],
) -> Type[Model]:
    class ReparametrisedModel(model):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

            self.__reparametrise_func = reparametrise_func
            # Not used yet: intended for calculating uncertainties of the
            # replaced parameters.
            self.__reparametrise_deriv_func = reparametrise_deriv_func
            self.__original_parameters = list(self.parameters.keys())

            if not all(
                parameter in self.parameters for parameter in replaced_parameters
            ):
                raise ValueError(
                    "Replaced model parameters must exist in original model"
                )

            # Move the replaced parameters out of the fit parameters and into
            # the internal parameters.
            self.__replaced_parameters = {
                param_name: self.parameters.pop(param_name)
                for param_name in replaced_parameters
            }
            self.internal_parameters.extend(self.__replaced_parameters.values())

            if any(key in self.parameters for key in new_parameters):
                raise ValueError(
                    "New parameter names must not duplicate existing model parameter names"
                )
            self.parameters.update(new_parameters)

        def func(
            self,
            x: Array[("num_samples",), np.float64],
            param_values: Dict[str, float],
        ) -> Array[("num_y_channels", "num_samples"), np.float64]:
            # Compute values for the replaced (original) parameters from the
            # new ones.
            re_param_values = self.__reparametrise_func(param_values)
            if re_param_values.keys() != self.__replaced_parameters.keys():
                raise ValueError(
                    "reparametrise_func must return values for exactly the "
                    "replaced parameters"
                )

            # Keep only the original model's parameters, then add back the
            # values calculated for the replaced ones.
            param_values = {
                param_name: param_value
                for param_name, param_value in param_values.items()
                if param_name in self.__original_parameters
            }
            param_values.update(re_param_values)
            return super().func(x=x, param_values=param_values)

    return ReparametrisedModel
```
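Since the sketch depends on ionics_fits internals, here is a self-contained toy version of the same subclass-factory pattern that runs without the library. `ToyModel` and `reparametrise_toy_model` are hypothetical stand-ins, not the ionics_fits API: parameters are plain names rather than `ModelParameter` objects and there is no uncertainty handling, so this only illustrates the mechanics.

```python
from typing import Callable, Dict, List, Type

import numpy as np


class ToyModel:
    """Hypothetical stand-in for a fitting model (NOT the ionics_fits Model)."""

    def __init__(self):
        self.parameters: List[str] = ["a", "omega", "phi", "y0"]
        self.internal_parameters: List[str] = []

    def func(self, x: np.ndarray, param_values: Dict[str, float]) -> np.ndarray:
        p = param_values
        return p["y0"] + p["a"] * np.sin(p["omega"] * x + p["phi"])


def reparametrise_toy_model(
    model: Type[ToyModel],
    new_parameters: List[str],
    replaced_parameters: List[str],
    reparametrise_func: Callable[[Dict[str, float]], Dict[str, float]],
) -> Type[ToyModel]:
    """Return a subclass of `model` whose `replaced_parameters` are computed
    from `new_parameters` by `reparametrise_func`."""

    class ReparametrisedModel(model):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            missing = set(replaced_parameters) - set(self.parameters)
            if missing:
                raise ValueError(f"Unknown parameters to replace: {sorted(missing)}")
            clashes = set(new_parameters) & set(self.parameters)
            if clashes:
                raise ValueError(f"New parameter names clash: {sorted(clashes)}")

            # Demote the replaced parameters and expose the new ones instead.
            self.parameters = [
                p for p in self.parameters if p not in replaced_parameters
            ]
            self.internal_parameters.extend(replaced_parameters)
            self.parameters.extend(new_parameters)

        def func(self, x, param_values):
            replaced = reparametrise_func(param_values)
            if set(replaced) != set(replaced_parameters):
                raise ValueError("reparametrise_func must return the replaced parameters")
            # Drop the new parameters and add the computed original ones.
            original = {k: v for k, v in param_values.items() if k not in new_parameters}
            original.update(replaced)
            return super().func(x, original)

    return ReparametrisedModel


MinMaxSinusoid = reparametrise_toy_model(
    ToyModel,
    new_parameters=["y_min", "y_max"],
    replaced_parameters=["a", "y0"],
    reparametrise_func=lambda p: {
        "a": 0.5 * (p["y_max"] - p["y_min"]),
        "y0": 0.5 * (p["y_max"] + p["y_min"]),
    },
)
model = MinMaxSinusoid()
x = np.linspace(0.0, 2.0 * np.pi, 5)
y = model.func(x, {"omega": 1.0, "phi": 0.0, "y_min": 0.2, "y_max": 0.8})
print(model.parameters, model.internal_parameters)
```

Building the subclass inside a factory gives each call its own independent class, so the same base model can be reparametrised in several different ways without the variants interfering with each other.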