lincc-frameworks / tdastro

MIT License

Update the core logic for updating parameters #56

Closed: jeremykubica closed this 3 months ago

jeremykubica commented 3 months ago

This is a prefactoring PR for some JAX-related changes. The goal is to modularize how the parameters are set. The changes should not impact current users and should be (mostly) behind the scenes.

What is changing:

1) The ParameterSource class is extended to track named information such as source_type or fixed. Previously this was tracked in an array entry and accessed with a magic index number. Named fields should make it easier to see what the code is doing and to add new attributes to the setters.

2) ParameterizedNode attributes can no longer be set by arbitrary methods of other ParameterizedNodes. Allowing this was preventing us from easily checking that the evaluation graph is a DAG.

3) The results of a FunctionNode are now accessed through its attributes. This prevents problems where the same parameter is fed into multiple nodes and resampled each time (in the compute() call). Previously this could lead to inconsistent results.

4) Results of a FunctionNode are collected via a function _save_result_parameters() to make overloading compute() easier.

5) Fixed the seed-setting logic in the NumPy and JAX function modules to account for the changes above.
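
To illustrate items 1 and 3, here is a minimal sketch of the two ideas: named setter fields instead of a magic array index, and a function result stored once as an attribute so every consumer sees the same draw. The names `ParamSetter`, `SamplingNode`, `sample()`, and `result` are hypothetical and are not the tdastro API; the real classes in this PR may differ.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ParamSetter:
    """Hypothetical stand-in for a parameter source with named fields."""

    name: str
    source_type: str = "undefined"  # e.g. "constant" or "function"
    fixed: bool = False             # read by name, not by a magic index
    value: Any = None


class SamplingNode:
    """Hypothetical function node that saves its result as an attribute."""

    def __init__(self, func: Callable[[random.Random], float], seed: Optional[int] = None):
        self._func = func
        self._rng = random.Random(seed)
        self.result: Optional[float] = None

    def sample(self) -> float:
        # Draw exactly once per call and store the value on the node.
        self.result = self._func(self._rng)
        return self.result


def demo():
    setter = ParamSetter(name="brightness", source_type="constant", fixed=True, value=10.0)

    node = SamplingNode(lambda rng: rng.random(), seed=42)
    node.sample()

    # Two downstream consumers read the stored attribute instead of
    # re-running the function, so both see the same sampled value.
    consumer_a = node.result + 1.0
    consumer_b = node.result * 2.0
    return setter, node.result, consumer_a, consumer_b
```

Reading `node.result` rather than calling the function again is what keeps the two consumers consistent; calling `sample()` a second time would produce a new draw and overwrite the stored value.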