rhayes777 / PyAutoFit

PyAutoFit: Classy Probabilistic Programming
https://pyautofit.readthedocs.io/
MIT License
60 stars 11 forks source link

Move search update to new class #1003

Open Jammy2211 opened 6 months ago

Jammy2211 commented 6 months ago

I am happy to do this, but I am raising this issue to discuss it first.

The search has update functionality that interacts with many different modules:


    def perform_update(
        self,
        model: AbstractPriorModel,
        analysis: Analysis,
        during_analysis: bool,
        search_internal=None,
    ) -> Samples:
        """
        Perform an update of the non-linear search's model-fitting results.

        This occurs every `iterations_per_update` of the non-linear search and once it is complete.

        The update performs the following tasks (if the settings indicate they should be performed):

        1) Output the samples summary and the samples above the weight threshold.
        2) Compute and output latent samples, if enabled by the `latent_during_fit` (during the search) or
           `latent_after_fit` (after the search) output settings.
        3) Visualize the maximum log likelihood model using model-specific visualization implemented via the
           `Analysis` object, and visualize the search results (see `perform_visualization`).
        4) Perform profiling of the analysis object `log_likelihood_function` and output run-time information.
        5) Output the model-fitting summary via `paths.save_summary`, including the measured run time of a
           single `log_likelihood_function` call at the maximum likelihood instance.

        Tasks 1-5 are only performed by the master process (`self.is_master`).

        Parameters
        ----------
        model
            The model which generates instances for different points in parameter space.
        analysis
            Contains the data and the log likelihood function which fits an instance of the model to the data, returning
            the log likelihood the `NonLinearSearch` maximizes.
        during_analysis
            If the update is during a non-linear search, in which case tasks are only performed after a certain number
            of updates and only a subset of visualization may be performed.
        search_internal
            Internal search state, passed through to `samples_from` and `perform_visualization`. May be None.

        Returns
        -------
        The samples computed via `samples_from`. If the samples summary has no valid instance (accessing it raises a
        `FitException`), these samples are returned immediately and no output is performed.
        """

        self.iterations += self.iterations_per_update
        if during_analysis:
            self.logger.info(
                f"""Fit Running: Updating results after {self.iterations} iterations (see output folder)."""
            )
        else:
            self.logger.info(
                "Fit Complete: Updating final results (see output folder)."
            )

        # The timer is not updated for database or null output paths.
        if not isinstance(self.paths, (DatabasePaths, NullPaths)):
            self.timer.update()

        samples = self.samples_from(model=model, search_internal=search_internal)
        samples_summary = samples.summary()

        try:
            instance = samples_summary.instance
        except exc.FitException:
            # No valid maximum likelihood instance yet, so there is nothing meaningful to output.
            return samples

        # Only the master process writes output.
        if not self.is_master:
            return samples

        self.paths.save_samples_summary(samples_summary=samples_summary)

        samples_save = samples.samples_above_weight_threshold_from(
            log_message=not during_analysis
        )
        self.paths.save_samples(samples=samples_save)

        # Whether latent samples are computed depends on a different setting during / after the fit.
        latent_samples = None
        latent_setting = "latent_during_fit" if during_analysis else "latent_after_fit"
        if conf.instance["output"][latent_setting]:
            latent_samples = analysis.compute_latent_samples(samples_save)

            if latent_samples:
                self.paths.save_latent_samples(latent_samples)

        self.perform_visualization(
            model=model,
            analysis=analysis,
            samples_summary=samples_summary,
            during_analysis=during_analysis,
            search_internal=search_internal,
        )

        if self.should_profile:
            self.logger.debug("Profiling Maximum Likelihood Model")
            analysis.profile_log_likelihood_function(
                paths=self.paths,
                instance=instance,
            )

        self.logger.debug("Outputting model result")
        try:
            log_likelihood_function = jax_wrapper.jit(
                analysis.log_likelihood_function,
            )
            # Warm-up call: presumably triggers JIT compilation when JAX is enabled, so it is excluded
            # from the timing below.
            log_likelihood_function(instance=instance)

            # perf_counter is monotonic and high-resolution, unlike the wall-clock time.time().
            # NOTE(review): with JAX's asynchronous dispatch the result may need to be blocked on for an
            # accurate timing — confirm.
            start = time.perf_counter()
            log_likelihood_function(instance=instance)
            log_likelihood_function_time = time.perf_counter() - start

            self.paths.save_summary(
                samples=samples,
                latent_samples=latent_samples,
                log_likelihood_function_time=log_likelihood_function_time,
            )
        except exc.FitException:
            # Deliberate best-effort: if the maximum likelihood instance cannot be evaluated, skip the summary.
            self.logger.debug(
                "Could not evaluate the maximum likelihood instance; summary not output"
            )

        if not during_analysis and self.remove_state_files_at_end:
            self.logger.debug("Removing state files")
            # NOTE(review): only a log message is emitted here; no files are removed by this method —
            # confirm whether the removal happens elsewhere.

        return samples

    def perform_visualization(
        self,
        model: AbstractPriorModel,
        analysis: Analysis,
        samples_summary: SamplesSummary,
        during_analysis: bool,
        search_internal=None,
    ):
        """
        Perform visualization of the non-linear search's model-fitting results.

        This occurs every `iterations_per_update` of the non-linear search and when the search is complete. It
        can also be forced to occur on a rerun of a completed search, to update the visualization with different
        `matplotlib` settings.

        The update performs the following tasks (if the settings indicate they should be performed):

        1) Visualize the maximum log likelihood model using model-specific visualization implemented via the
           `Analysis` object (`visualize` and `visualize_combined`).
        2) Visualize the search results via `plot_results`, unless the output paths are `NullPaths`.

        Parameters
        ----------
        model
            The model which generates instances for different points in parameter space.
        analysis
            Contains the data and the log likelihood function which fits an instance of the model to the data, returning
            the log likelihood the `NonLinearSearch` maximizes.
        samples_summary
            Summary of the samples; its `instance` is the model instance passed to the `Analysis` visualization.
        during_analysis
            If the update is during a non-linear search, in which case tasks are only performed after a certain number
            of updates and only a subset of visualization may be performed.
        search_internal
            Internal search state, passed through to `samples_from` when recomputing samples for plotting.
            May be None.
        """

        self.logger.debug("Visualizing")

        if analysis.should_visualize(paths=self.paths, during_analysis=during_analysis):
            analysis.visualize(
                paths=self.paths,
                instance=samples_summary.instance,
                during_analysis=during_analysis,
            )
            analysis.visualize_combined(
                paths=self.paths,
                instance=samples_summary.instance,
                during_analysis=during_analysis,
            )

        # NOTE(review): this repeats the check above with identical arguments; the two branches could be
        # merged if `should_visualize` is side-effect free and unaffected by the visualization above — confirm.
        if analysis.should_visualize(paths=self.paths, during_analysis=during_analysis):
            if not isinstance(self.paths, NullPaths):
                # NOTE(review): samples are recomputed here although callers such as `perform_update`
                # already hold them.
                samples = self.samples_from(
                    model=model, search_internal=search_internal
                )

                self.plot_results(samples=samples)

I am going to make a `SearchUpdate` class:

class SearchUpdate:
    """
    Performs the output tasks of a non-linear search update (e.g. saving samples, visualization, profiling and
    summary output), with one method per output so that functionality can be reused.
    """

    def __init__(self, model, analysis, search_internal=None):
        """
        Store the objects that every update output task needs.

        Parameters
        ----------
        model
            The model which generates instances for different points in parameter space.
        analysis
            Contains the data and the log likelihood function which fits an instance of the model to the data.
        search_internal
            Internal search state, passed through when computing samples. May be None.
        """
        self.model = model
        self.analysis = analysis
        self.search_internal = search_internal

This should allow me to write a method for each specific output and to reuse functionality, which will keep the code clean.

I'm pretty sure this is a good idea, but I'm posting it here first in case any better refactors spring to mind.