I am happy to do this, but I am raising this issue to discuss it first.
The search has update functionality, which interacts with a lot of different modules:
def perform_update(
    self,
    model: AbstractPriorModel,
    analysis: Analysis,
    during_analysis: bool,
    search_internal=None,
) -> Samples:
    """
    Perform an update of the non-linear search's model-fitting results.

    This occurs every `iterations_per_update` of the non-linear search and once it is complete.

    The update performs the following tasks (if the settings indicate they should be performed):

    1) Visualize the search results.
    2) Visualize the maximum log likelihood model using model-specific visualization implemented via the
       `Analysis` object.
    3) Perform profiling of the analysis object `log_likelihood_function` and output run-time information.
    4) Output the `search.summary` file which contains information on model-fitting so far.
    5) Output the `model.results` file which contains a concise text summary of the model results so far.

    Parameters
    ----------
    model
        The model which generates instances for different points in parameter space.
    analysis
        Contains the data and the log likelihood function which fits an instance of the model to the data,
        returning the log likelihood the `NonLinearSearch` maximizes.
    during_analysis
        If the update is during a non-linear search, in which case tasks are only performed after a certain
        number of updates and only a subset of visualization may be performed.
    """
    self.iterations += self.iterations_per_update

    if during_analysis:
        self.logger.info(
            f"""Fit Running: Updating results after {self.iterations} iterations (see output folder)."""
        )
    else:
        self.logger.info(
            "Fit Complete: Updating final results (see output folder)."
        )

    if not isinstance(self.paths, DatabasePaths) and not isinstance(
        self.paths, NullPaths
    ):
        self.timer.update()

    samples = self.samples_from(model=model, search_internal=search_internal)
    samples_summary = samples.summary()

    try:
        instance = samples_summary.instance
    except exc.FitException:
        return samples

    if self.is_master:
        self.paths.save_samples_summary(samples_summary=samples_summary)

        samples_save = samples
        samples_save = samples_save.samples_above_weight_threshold_from(
            log_message=not during_analysis
        )
        self.paths.save_samples(samples=samples_save)

        latent_samples = None

        if (during_analysis and conf.instance["output"]["latent_during_fit"]) or (
            not during_analysis and conf.instance["output"]["latent_after_fit"]
        ):
            latent_samples = analysis.compute_latent_samples(samples_save)

        if latent_samples:
            self.paths.save_latent_samples(
                latent_samples
            )

        self.perform_visualization(
            model=model,
            analysis=analysis,
            samples_summary=samples_summary,
            during_analysis=during_analysis,
            search_internal=search_internal,
        )

        if self.should_profile:
            self.logger.debug("Profiling Maximum Likelihood Model")
            analysis.profile_log_likelihood_function(
                paths=self.paths,
                instance=instance,
            )

        self.logger.debug("Outputting model result")

        try:
            log_likelihood_function = jax_wrapper.jit(
                analysis.log_likelihood_function,
            )
            log_likelihood_function(instance=instance)

            start = time.time()
            log_likelihood_function(instance=instance)
            log_likelihood_function_time = time.time() - start

            self.paths.save_summary(
                samples=samples,
                latent_samples=latent_samples,
                log_likelihood_function_time=log_likelihood_function_time,
            )
        except exc.FitException:
            pass

        if not during_analysis and self.remove_state_files_at_end:
            self.logger.debug("Removing state files")

    return samples
def perform_visualization(
    self,
    model: AbstractPriorModel,
    analysis: Analysis,
    samples_summary: SamplesSummary,
    during_analysis: bool,
    search_internal=None,
):
    """
    Perform visualization of the non-linear search's model-fitting results.

    This occurs every `iterations_per_update` of the non-linear search and when the search is complete. It
    can also be forced to occur on a rerun of a completed search, to update the visualization with different
    `matplotlib` settings.

    The update performs the following tasks (if the settings indicate they should be performed):

    1) Visualize the maximum log likelihood model using model-specific visualization implemented via the
       `Analysis` object.
    2) Visualize the search results.

    Parameters
    ----------
    model
        The model which generates instances for different points in parameter space.
    analysis
        Contains the data and the log likelihood function which fits an instance of the model to the data,
        returning the log likelihood the `NonLinearSearch` maximizes.
    during_analysis
        If the update is during a non-linear search, in which case tasks are only performed after a certain
        number of updates and only a subset of visualization may be performed.
    """
    self.logger.debug("Visualizing")

    if analysis.should_visualize(paths=self.paths, during_analysis=during_analysis):
        analysis.visualize(
            paths=self.paths,
            instance=samples_summary.instance,
            during_analysis=during_analysis,
        )
        analysis.visualize_combined(
            paths=self.paths,
            instance=samples_summary.instance,
            during_analysis=during_analysis,
        )

    if analysis.should_visualize(paths=self.paths, during_analysis=during_analysis):
        if not isinstance(self.paths, NullPaths):
            samples = self.samples_from(
                model=model, search_internal=search_internal
            )
            self.plot_results(samples=samples)
I am going to make a SearchUpdate class:
class SearchUpdate:
    def __init__(self, model, analysis, search_internal):
        ....
This should allow me to write a method for each specific thing that is output and reuse functionality for cleanliness; a rough sketch of what I have in mind is below.
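To make the discussion concrete, here is a minimal sketch of the kind of class I am imagining. Every name and signature below is illustrative only (none of it is an existing API), and the per-task methods are stubs:

class SearchUpdate:
    """
    Sketch only: holds the objects perform_update currently threads through every
    step, with one method per output task so each can be reused and tested on its own.
    """

    def __init__(self, search, model, analysis, during_analysis, search_internal=None):
        self.search = search
        self.model = model
        self.analysis = analysis
        self.during_analysis = during_analysis
        self.search_internal = search_internal

    def samples(self):
        # Reuses the search's existing samples_from method.
        return self.search.samples_from(
            model=self.model, search_internal=self.search_internal
        )

    def save_samples(self, samples):
        # e.g. the weight thresholding and paths.save_samples / save_samples_summary calls.
        ...

    def save_latent_samples(self, samples):
        # e.g. the conf-dependent latent samples computation and output.
        ...

    def visualize(self, samples_summary):
        # e.g. the analysis.visualize / visualize_combined / plot_results calls.
        ...

    def output_summary(self, samples, latent_samples):
        # e.g. the log_likelihood_function timing and paths.save_summary call.
        ...

    def perform(self):
        # Orchestrates the steps perform_update currently inlines, so each one
        # can be switched on or off and overridden independently.
        samples = self.samples()
        samples_summary = samples.summary()
        self.save_samples(samples=samples)
        self.visualize(samples_summary=samples_summary)
        self.output_summary(samples=samples, latent_samples=None)
        return samples

The existing perform_update would then just construct a SearchUpdate and call perform(), rather than inlining every step itself.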
I am pretty sure this is a good idea, but in case any better refactors spring to mind I am putting it here first.