Closed: jobrachem closed this 3 months ago
We should think about the batching and how/whether it generalizes to different models. Right now it assumes iid data.
Maybe give it a different name that clearly identifies that it assumes iid data.
I had a quick look at the code again. This is the function that does the batching. It is applied to `Var` objects that are marked as `Var.observed`. To be honest, I currently don't understand how this implies an assumption of iid data.
```python
import jax
from jax import Array

def batched_nodes(nodes: dict[str, Array], batch_indices: Array) -> dict[str, Array]:
    """Returns a subset of the graph state using the given batch indices."""
    return jax.tree_util.tree_map(lambda x: x[batch_indices, ...], nodes)
```
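For illustration, here is a quick usage sketch of that function (the example node values are made up; the function itself is repeated so the snippet is self-contained):

```python
import jax
import jax.numpy as jnp

def batched_nodes(nodes, batch_indices):
    """Returns a subset of the graph state using the given batch indices."""
    return jax.tree_util.tree_map(lambda x: x[batch_indices, ...], nodes)

# two observed nodes sharing the same leading (observation) dimension
nodes = {"y": jnp.arange(10.0), "x": jnp.ones((10, 3))}
batch = batched_nodes(nodes, jnp.array([0, 3, 7]))
```

Every leaf is indexed along its first axis with the same index vector, which is exactly why all observed nodes must share that leading dimension.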
It assumes that each node marked as observed can be indexed in the same way.
Consider a hierarchical model (and do not worry about the validity of the estimation):
Let's assume that $Y_{ij}$ is the outcome variable for the $i$-th level-1 unit in the $j$-th level-2 unit. The covariates at level 1 and level 2 are represented by $X_{ij}$ and $W_{j}$, respectively. The model can be represented as follows:

Level-1 model (within groups): $$Y_{ij} = \beta_{0j} + \beta_{1j}X_{ij} + \epsilon_{ij}$$ where $\epsilon_{ij}$ is the level-1 error term.

Level-2 model (between groups): $$\beta_{0j} = \gamma_{00} + \gamma_{01}W_{j} + u_{0j}$$ $$\beta_{1j} = \gamma_{10} + \gamma_{11}W_{j} + u_{1j}$$ where $u_{0j}$ and $u_{1j}$ are level-2 error terms.

The design matrices $\boldsymbol{W}$ and $\boldsymbol{X}$, arising from stacking the covariates, have different numbers of rows and cannot be indexed in the same way.
Thanks for the clarification @wiep !
> It assumes that each node marked as observed can be indexed in the same way.
Indeed it does. I would call this a "flat" way of indexing, as opposed to "nested" indexing. Pointing this out in a way that cannot be missed, or is at least hard to miss, would be important. This limitation only applies to batching, though; the rest of the function does not make such an assumption. So I am not certain whether changing the name of the function to something like `optim_flat` would be necessary; it seems a bit overly restrictive to me.
Maybe the function can be designed in a way that allows the use of different, also user-supplied batching strategies. I would have to take a closer look at the code again to see whether this may even be easy to implement.
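One possible shape for such an extension point (entirely hypothetical, not the current API): the optimizer could accept any callable that maps a random generator, the number of observations, and the batch size to batch indices, with iid subsampling as the default.

```python
import numpy as np

def iid_batches(rng: np.random.Generator, n: int, batch_size: int):
    """Default strategy: random iid batches; drops the last incomplete batch."""
    perm = rng.permutation(n)
    n_full = n // batch_size
    return perm[: n_full * batch_size].reshape(n_full, batch_size)

def run_epoch(data, batch_size, batch_strategy=iid_batches, seed=0):
    """Sketch of a loop that delegates batching to a user-supplied strategy."""
    rng = np.random.default_rng(seed)
    batches = batch_strategy(rng, len(data), batch_size)
    for batch_indices in batches:
        _ = data[batch_indices]  # a gradient step on this batch would go here
    return batches

batches = run_epoch(np.arange(10.0), batch_size=3)
```

A user who needs group-aware or stratified batching would then only swap out `batch_strategy`.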
> Consider a hierarchical model (and do not worry about the validity of the estimation):
>
> Let's assume that $Y_{ij}$ is the outcome variable for the $i$-th level-1 unit in the $j$-th level-2 unit. The covariates at level 1 and level 2 are represented by $X_{ij}$ and $W_{j}$, respectively. The model can be represented as follows:
>
> Level-1 model (within groups): $Y_{ij} = \beta_{0j} + \beta_{1j}X_{ij} + \epsilon_{ij}$ where $\epsilon_{ij}$ is the level-1 error term.
>
> Level-2 model (between groups): $\beta_{0j} = \gamma_{00} + \gamma_{01}W_{j} + u_{0j}$ and $\beta_{1j} = \gamma_{10} + \gamma_{11}W_{j} + u_{1j}$, where $u_{0j}$ and $u_{1j}$ are level-2 error terms.
>
> The design matrices $\boldsymbol{W}$ and $\boldsymbol{X}$, arising from stacking the covariates, have different numbers of rows and cannot be indexed in the same way.
I do see the point that this model, as it is, would not fit into the batching scheme. You can, however, write down the model in a "flat" way that allows the batching to work. Still, of course, that does not eliminate the fact that the `batch_size` argument of the proposed `optim` function has this limitation. This makes it even clearer how important it would be to alert users to the way batching is done.
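To make that flattening concrete (a sketch with made-up shapes): repeat the level-2 rows of $\boldsymbol{W}$ according to each observation's group index, so every observed node shares the same leading dimension and one batch index vector works for all of them.

```python
import numpy as np

rng = np.random.default_rng(1)
n_groups, n_obs = 4, 12
group = rng.integers(0, n_groups, size=n_obs)  # level-2 unit of each observation
X = rng.normal(size=(n_obs, 2))                # level-1 covariates, one row per obs
W = rng.normal(size=(n_groups, 3))             # level-2 covariates, one row per group

# "flat" representation: broadcast W up to one row per observation
W_flat = W[group]                              # shape (n_obs, 3)

# now X and W_flat share the leading axis, so one batch index works for both
batch = np.array([0, 5, 9])
X_batch, W_batch = X[batch], W_flat[batch]
```

The cost is redundancy in memory; the benefit is that the simple flat indexing scheme applies unchanged.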
I think that having different batching strategies would be really nice. Consider also the case in which you have a binary response and a really unbalanced dataset with a lot of 0s and only a few 1s. You might want a batching strategy that guarantees to always have some 1s in each batch.
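A stratified strategy of that kind could look roughly like this (a sketch; the function name and the `min_positives` knob are made up):

```python
import numpy as np

def stratified_batch(rng, y, batch_size, min_positives=2):
    """Sample a batch that is guaranteed to contain some 1s."""
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    n_pos = min(min_positives, len(pos))
    chosen_pos = rng.choice(pos, size=n_pos, replace=False)
    chosen_neg = rng.choice(neg, size=batch_size - n_pos, replace=False)
    batch = np.concatenate([chosen_pos, chosen_neg])
    rng.shuffle(batch)
    return batch
```

Note that this changes the sampling distribution of the batches, which is one more reason to make the strategy explicit and user-visible.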
We will use `optim_flat` for now. This also gives us room to favor a different implementation under the name `optim` later without backwards-compatibility issues.
@wiep if you got a notification for this, you can ignore it. @GianmarcoCallegher is doing the review.
Do we want to always have the negative log probability as the function to optimize, or do we want to provide another parameter which could be the function to optimize (e.g. RMSE)?
> Do we want to always have the negative log probability as the function to optimize, or do we want to provide another parameter which could be the function to optimize (e.g. RMSE)?
Also a really good point. I guess it would be relatively easy to let users provide their own function for optimization.
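For instance, the loop could accept any callable of the parameters and the (batched) data, with the negative log probability remaining the default (a hypothetical sketch, not the actual interface):

```python
import numpy as np

def rmse_loss(params, x, y):
    """Example of a user-supplied objective: root mean squared error."""
    pred = x @ params
    return float(np.sqrt(np.mean((pred - y) ** 2)))

# the optimizer would minimize this callable instead of -log p(data, params)
params = np.array([1.0, 2.0])
x = np.eye(2)
y = np.array([1.0, 2.0])
loss = rmse_loss(params, x, y)
```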
Thank you very much, @jobrachem.
Some general remarks:

- `model_test = model_test if model_test is not None else model`. If there is no validation set, you just run the model for all epochs on the training set.

What do you think? ❤️
Thank you very much for the speedy review @GianmarcoCallegher ! Your suggestions make sense to me, I will implement them.
Okay, I finally implemented a big update. Now we have the following:
- Batching is implemented as suggested by @GianmarcoCallegher (i.e. randomly assembled batches in each iteration) with the modification suggested by @wiep (i.e. if the number of observations is not divisible by the batch size, the last batch is simply dropped in that iteration).
- The function now saves a history of training and (if applicable) testing loss, as well as a history of parameter values. The latter can be turned off to save memory.
- Early stopping is now implemented with patience.
- If the user wishes, they can supply a `test_model`. If supplied, the test loss is used to determine early stopping. For fully Bayesian optimization I think using a test model would usually be a little weird, because it affects the balance of likelihood and prior. But there might be cases in which one would want to fit a likelihood model with fully constant priors.

Here's a notebook with a minimal example for testing:
Do you think the current defaults are good?

- Especially the defaults for the `stopper` are quite arbitrarily chosen right now.
- `save_position_history=True`
- `restore_best_position=True`
- `prune_history=True`
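The patience-based early stopping mentioned in the list above can be sketched like this (a simplified stand-in; the names `should_stop` and `atol` are made up, not the actual `stopper` API):

```python
def should_stop(loss_history, patience: int, atol: float = 0.0) -> bool:
    """Stop when the loss has not improved for `patience` consecutive steps."""
    if len(loss_history) <= patience:
        return False
    best_before = min(loss_history[:-patience])
    recent_best = min(loss_history[-patience:])
    return recent_best >= best_before - atol
```

Under this scheme the default for `patience` directly trades off robustness to noisy (mini-batch) losses against wasted iterations.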
Concerning the default values, I think they are totally fine. Maybe I would not have a default number of iterations; it should have to be specified by the user.
Over lunch, @wiep mentioned the practice of rescaling the likelihood when using mini-batching in order to ensure that the prior does not get more power than warranted. I think this should be implemented in our function here, too. It can be used both in general for minibatching, and also in order to correct for a potentially smaller sample size in a validation model.
> Over lunch, @wiep mentioned the practice of rescaling the likelihood when using mini-batching in order to ensure that the prior does not get more power than warranted. I think this should be implemented in our function here, too. It can be used both in general for minibatching, and also in order to correct for a potentially smaller sample size in a validation model.
Of course, if the objective is the log joint 😄
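The rescaling works by weighting the batch log-likelihood by `n_total / n_batch`, so that it is an unbiased estimate of the full-data log-likelihood before the prior is added. A minimal sketch (the function name is made up):

```python
def scaled_log_joint(log_lik_batch: float, log_prior: float,
                     n_total: int, n_batch: int) -> float:
    """Rescale the batch log-likelihood by n_total / n_batch so the prior
    keeps its intended weight relative to the full likelihood."""
    return (n_total / n_batch) * log_lik_batch + log_prior
```

Without the factor, a batch of 10 out of 100 observations would let the prior weigh ten times as much as it would on the full dataset.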
@GianmarcoCallegher I implemented my changes. If you are satisfied, please squash & merge 😊
Everything LGTM, @jobrachem. I just left one final comment (probably it is not an error, I am just confused).
Thanks @GianmarcoCallegher ! See my reply above :)
This PR introduces the function `gs.optim`.

This function provides a quick and easy way to apply stochastic gradient descent to a `lsl.Model`, which can be nice for providing a "warm start" for MCMC sampling. All loops and batching are implemented in JAX.

Optimization can be applied on the whole dataset or in batches. By default, the Adam optimizer is used with a fixed learning rate of `0.1`. Users can supply an optimizer of their choice from the `optax` library.

The PR includes documentation and unit tests.
Here's a notebook for playing around: 051-optim.ipynb.zip