liesel-devs / liesel

A probabilistic programming framework
https://liesel-project.org
MIT License

A function for applying stochastic gradient descent to a Liesel model #175

Closed: jobrachem closed this 3 months ago

jobrachem commented 7 months ago

This PR introduces the function gs.optim.

This function provides a quick-and-easy way to apply stochastic gradient descent to an lsl.Model, which can be useful for providing a "warm start" for MCMC sampling. All loops and batching are implemented in JAX.

Optimization can be applied on the whole dataset or in batches. By default, the Adam optimizer is used with a fixed learning rate of 0.1. Users can supply an optimizer of their choice from the optax library.
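
A minimal sketch of the idea in plain JAX. This is not the Liesel implementation; the toy Gaussian model and the names `neg_log_prob` and `adam_optimize` are hypothetical. It runs full-batch Adam with the default learning rate of 0.1 and keeps the loop inside JAX via `jax.lax.scan`. To avoid an optax dependency here, the Adam update is written out by hand.

```python
import jax
import jax.numpy as jnp


def neg_log_prob(params, y):
    # Toy model (assumption): y_i ~ Normal(mu, exp(log_sigma)) with flat priors,
    # so the objective reduces to the negative log-likelihood.
    mu, log_sigma = params["mu"], params["log_sigma"]
    z = (y - mu) / jnp.exp(log_sigma)
    return jnp.sum(0.5 * z**2 + log_sigma + 0.5 * jnp.log(2 * jnp.pi))


def adam_optimize(loss_fn, params, data, n_steps=500, lr=0.1,
                  b1=0.9, b2=0.999, eps=1e-8):
    """Full-batch Adam written out by hand; the loop runs inside jax.lax.scan."""
    # Strongly typed float32 leaves keep the scan carry dtypes stable.
    params = jax.tree_util.tree_map(lambda p: jnp.asarray(p, dtype=jnp.float32), params)
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_fn = jax.value_and_grad(loss_fn)

    def step(carry, t):
        params, m, v = carry
        loss, g = grad_fn(params, data)
        # First and second moment estimates.
        m = jax.tree_util.tree_map(lambda m_, g_: b1 * m_ + (1 - b1) * g_, m, g)
        v = jax.tree_util.tree_map(lambda v_, g_: b2 * v_ + (1 - b2) * g_**2, v, g)

        def update(p, m_, v_):
            # Bias correction with the 1-based step count t.
            m_hat = m_ / (1 - b1**t)
            v_hat = v_ / (1 - b2**t)
            return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

        params = jax.tree_util.tree_map(update, params, m, v)
        return (params, m, v), loss

    steps = jnp.arange(1, n_steps + 1, dtype=jnp.float32)
    (params, _, _), loss_history = jax.lax.scan(step, (params, zeros, zeros), steps)
    return params, loss_history
```

Because `scan` stacks the per-step outputs, the training-loss history comes out as a by-product of the loop.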

The PR includes documentation and unit tests.

Here's a notebook for playing around: 051-optim.ipynb.zip

jobrachem commented 6 months ago

We should think about the batching and how/whether it generalizes to different models. Right now it assumes iid data.

jobrachem commented 6 months ago

Maybe give it a different name that makes clear that it assumes iid data.

jobrachem commented 6 months ago

I had a quick look at the code again. This is the function that does the batching. It is applied to Var objects that are marked as Var.observed. To be honest, I currently don't understand how this implies an assumption of iid data.

```python
import jax
from jax import Array

def batched_nodes(nodes: dict[str, Array], batch_indices: Array) -> dict[str, Array]:
    """Returns a subset of the graph state using the given batch indices."""
    return jax.tree_util.tree_map(lambda x: x[batch_indices, ...], nodes)
```

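
A small usage example of `batched_nodes` with made-up observed nodes. It shows that one index vector is applied along axis 0 of every observed array. That only makes sense if all observed nodes share the same first dimension, which is the assumption discussed below.

```python
import jax
import jax.numpy as jnp
from jax import Array


def batched_nodes(nodes: dict[str, Array], batch_indices: Array) -> dict[str, Array]:
    """Returns a subset of the graph state using the given batch indices."""
    return jax.tree_util.tree_map(lambda x: x[batch_indices, ...], nodes)


# Hypothetical observed nodes with the same leading dimension (n = 6).
nodes = {
    "y": jnp.arange(6.0),                 # shape (6,)
    "X": jnp.arange(12.0).reshape(6, 2),  # shape (6, 2)
}
# Rows 0, 2 and 4 are taken from *every* node.
batch = batched_nodes(nodes, jnp.array([0, 2, 4]))
```
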
wiep commented 6 months ago

It assumes that each node marked as observed can be indexed in the same way.

Consider a hierarchical model (and do not worry about the validity of the estimation):

Let's assume that $Y_{ij}$ is the outcome variable for the $i$-th level-1 unit in the $j$-th level-2 unit. The covariates at level-1 and level-2 are represented by $X_{ij}$ and $W_{j}$, respectively. The model can be represented as follows:

Level-1 model (within groups): $$Y_{ij} = \beta_{0j} + \beta_{1j}X_{ij} + \epsilon_{ij}$$ where $\epsilon_{ij}$ is the level-1 error term.

Level-2 model (between groups): $$\beta_{0j} = \gamma_{00} + \gamma_{01}W_{j} + u_{0j}$$ $$\beta_{1j} = \gamma_{10} + \gamma_{11}W_{j} + u_{1j}$$ where $u_{0j}$ and $u_{1j}$ are level-2 error terms.

The design matrices $\boldsymbol{W}$ and $\boldsymbol{X}$, arising from stacking the covariates, have different numbers of rows ($\boldsymbol{X}$ has one row per level-1 unit, $\boldsymbol{W}$ one row per level-2 unit) and cannot be indexed in the same way.

jobrachem commented 6 months ago

Thanks for the clarification @wiep !

> It assumes that each node marked as observed can be indexed in the same way.

Indeed it does. I would call this a "flat" way of indexing, as opposed to "nested" indexing. It will be important to point this out in a way that is hard to miss. This limitation only applies to batching, though; the rest of the function does not make such an assumption. So I am not certain whether renaming the function to something like optim_flat is necessary, as that seems a bit overly restrictive to me.
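
One way to make the flat-indexing assumption hard to miss would be to check it before batching and fail loudly. Below is a minimal sketch of such a guard. `check_flat_batchable` is a hypothetical helper, not part of Liesel, and it assumes the observed values are collected in a dict of arrays.

```python
import jax.numpy as jnp


def check_flat_batchable(observed: dict) -> int:
    """Hypothetical guard for flat batching.

    Returns the shared number of rows, or raises a ValueError listing each
    node's leading dimension if the observed arrays cannot all be indexed
    with the same batch indices along axis 0.
    """
    leading = {
        name: (jnp.shape(x)[0] if jnp.ndim(x) > 0 else None)  # None for scalars
        for name, x in observed.items()
    }
    sizes = set(leading.values())
    if len(sizes) != 1 or None in sizes:
        raise ValueError(
            "Flat batching needs every observed node to have the same number "
            f"of rows along axis 0, but got {leading}."
        )
    return sizes.pop()
```

With the hierarchical example above, $\boldsymbol{X}$ and $\boldsymbol{W}$ would have different leading dimensions, so the guard would raise instead of batching silently.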

Maybe the function can be designed to allow different batching strategies, including user-supplied ones. I would have to take a closer look at the code again to see whether this is easy to implement.
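
A minimal sketch of what a pluggable batching strategy could look like. The signature `(key, n_obs, batch_size) -> indices of shape (n_batches, batch_size)` and the names `BatchStrategy`, `sequential_batches` and `make_batches` are assumptions for illustration. The point is that the optimizer would only consume index arrays, so any strategy with that signature could be swapped in.

```python
from typing import Callable

import jax
import jax.numpy as jnp

# Hypothetical interface: (key, n_obs, batch_size) -> int array (n_batches, batch_size).
BatchStrategy = Callable[[jax.Array, int, int], jax.Array]


def sequential_batches(key: jax.Array, n_obs: int, batch_size: int) -> jax.Array:
    """Contiguous, unshuffled batches; the key is ignored and the remainder dropped."""
    n_batches = n_obs // batch_size
    return jnp.arange(n_batches * batch_size).reshape(n_batches, batch_size)


def make_batches(nodes: dict, key: jax.Array, batch_size: int,
                 strategy: BatchStrategy = sequential_batches) -> dict:
    """Apply a (possibly user-supplied) strategy to all observed nodes at once.

    Every leaf of the result gains a leading (n_batches,) axis.
    """
    n_obs = jax.tree_util.tree_leaves(nodes)[0].shape[0]
    idx = strategy(key, n_obs, batch_size)
    return jax.tree_util.tree_map(lambda x: x[idx], nodes)
```

A user-defined strategy is then just another function with the same signature, e.g. `make_batches(nodes, key, 3, strategy=my_strategy)`.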

I do see the point that this hierarchical model, as written, would not fit into the batching scheme. You can, however, write the model down in a "flat" way that allows the batching to work. Of course, that does not change the fact that the batch_size argument of the proposed optim function has this limitation. This makes it even clearer how important it is to alert users to the way batching is done.
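
A small sketch of the "flat" rewrite with made-up data. The level-2 covariate $W_j$ is expanded to one row per level-1 unit via a group index, and the level-2 model is substituted into the level-1 model. After that, every observed array has $n$ rows and can share the same batch indices. The level-2 effects $u_{0j}, u_{1j}$ are treated as a parameter array of shape $(J, 2)$ and are not batched. The names `group`, `W_flat`, `gamma`, `u` and `linear_predictor` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical data: n = 5 level-1 units in J = 2 level-2 units.
group = jnp.array([0, 0, 0, 1, 1])         # j for each level-1 unit i
X = jnp.array([0.5, 1.0, 1.5, 2.0, 2.5])   # level-1 covariate X_ij, shape (n,)
W = jnp.array([10.0, 20.0])                # level-2 covariate W_j, shape (J,)

# Flat representation: one row of W per level-1 unit, shape (n,).
W_flat = W[group]


def linear_predictor(gamma, u, X, W_flat, group):
    """Mean of Y_ij after substituting the level-2 model into the level-1 model."""
    beta0 = gamma["00"] + gamma["01"] * W_flat + u[group, 0]  # beta_0j per row
    beta1 = gamma["10"] + gamma["11"] * W_flat + u[group, 1]  # beta_1j per row
    return beta0 + beta1 * X
```

Since `X`, `W_flat` and `group` all have $n$ rows, indexing them with the same batch indices gives the predictor for exactly those observations.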

GianmarcoCallegher commented 6 months ago

I think that having different batching strategies would be really nice. Consider also the case of a binary response in a highly unbalanced dataset with a lot of 0s and only a few 1s. You might want a batching strategy that guarantees that every batch contains some 1s.
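
A minimal sketch of such a strategy for a binary response. `stratified_batches` is hypothetical. It assumes that the class indices may be computed with NumPy outside of `jit`, since their number depends on the data. Each batch gets a fixed number of positives, drawn with replacement because they are rare, plus a share of the shuffled negatives. Oversampling the positives means the per-batch loss no longer weights observations uniformly, so a weighted loss may be needed on top of this.

```python
import jax
import jax.numpy as jnp
import numpy as np


def stratified_batches(key, y, batch_size: int, n_pos_per_batch: int):
    """Hypothetical strategy: every batch contains n_pos_per_batch rows with y == 1."""
    y_np = np.asarray(y)
    pos = np.flatnonzero(y_np == 1)
    neg = np.flatnonzero(y_np == 0)
    n_neg_per_batch = batch_size - n_pos_per_batch
    n_batches = len(neg) // n_neg_per_batch  # incomplete remainder of negatives dropped
    key_pos, key_neg = jax.random.split(key)
    # Rare positives are resampled with replacement across batches.
    pos_idx = jax.random.choice(key_pos, pos, shape=(n_batches, n_pos_per_batch), replace=True)
    # Negatives are shuffled and each one is used at most once.
    neg_idx = jax.random.permutation(key_neg, neg)[: n_batches * n_neg_per_batch]
    neg_idx = neg_idx.reshape(n_batches, n_neg_per_batch)
    return jnp.concatenate([pos_idx, neg_idx], axis=1)
```
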

jobrachem commented 6 months ago

We will use optim_flat for now. This also leaves us room to use the name optim for a different implementation later without backwards-compatibility issues.

jobrachem commented 6 months ago

@wiep if you got a notification for this, you can ignore it. @GianmarcoCallegher is doing the review.

GianmarcoCallegher commented 5 months ago

Do we always want to use the negative log probability as the function to optimize, or do we want to provide another parameter that lets users supply the function to optimize (e.g. RMSE)?

jobrachem commented 5 months ago

Also a really good point. I guess it would be relatively easy to let users provide their own function for optimization.
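
A sketch of what a user-supplied objective could look like. The only contract assumed here is `loss_fn(params, *data) -> scalar`. Under that contract, the negative log probability would be just one possible default, and something like RMSE fits in the same slot. `rmse_loss` and `gradient_step` are hypothetical names, and plain gradient descent stands in for the optimizer to keep the example short.

```python
from typing import Callable

import jax
import jax.numpy as jnp


def rmse_loss(params: dict, X: jax.Array, y: jax.Array) -> jax.Array:
    """Root-mean-squared error of a linear predictor (an illustrative objective)."""
    pred = X @ params["beta"]
    return jnp.sqrt(jnp.mean((y - pred) ** 2))


def gradient_step(loss_fn: Callable, params: dict, *data, lr: float = 0.1) -> dict:
    """One plain gradient-descent step on any loss_fn(params, *data) -> scalar."""
    grads = jax.grad(loss_fn)(params, *data)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```
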

jobrachem commented 4 months ago

Okay, I finally implemented a big update. Now we have the following:

  1. Batching is implemented as suggested by @GianmarcoCallegher (i.e. randomly assembled batches in each iteration) with the modification suggested by @wiep (i.e. if the number of observations is not divisible by the batch size, the last batch is simply dropped in that iteration).
  2. The function now saves a history of training and (if applicable) testing loss, as well as a history of parameter values. The latter can be turned off to save memory.
  3. Early stopping is now implemented with patience.
  4. If the user wishes, they can supply a test_model. If supplied, the test loss is used to determine early stopping. For fully Bayesian optimization, I think using a test model would usually be a little weird, because it affects the balance of likelihood and prior. But there might be cases in which one would want to fit a likelihood model with fully constant priors.
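
A minimal sketch of the batching described in point 1. All rows are reshuffled in every iteration and cut into full batches. Rows that do not fill a complete batch are dropped for that iteration, so the dropped rows generally differ between iterations. The function names are hypothetical, and the per-iteration batches are generated up front with `jax.vmap` over split keys. That is one possible choice, not necessarily how the PR does it.

```python
import jax
import jax.numpy as jnp


def random_batch_indices(key: jax.Array, n_obs: int, batch_size: int) -> jax.Array:
    """Randomly assemble full batches; rows that do not fill a batch are dropped."""
    n_batches = n_obs // batch_size
    perm = jax.random.permutation(key, n_obs)
    return perm[: n_batches * batch_size].reshape(n_batches, batch_size)


def batch_indices_per_iteration(key: jax.Array, n_iter: int, n_obs: int,
                                batch_size: int) -> jax.Array:
    """Fresh random batches for every iteration.

    Result has shape (n_iter, n_obs // batch_size, batch_size).
    """
    keys = jax.random.split(key, n_iter)
    return jax.vmap(lambda k: random_batch_indices(k, n_obs, batch_size))(keys)
```
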

Here's a notebook with a minimal example for testing:

075-optim.ipynb.zip

Do you think the current defaults are good?

  1. The defaults for the stopper in particular are chosen quite arbitrarily right now.
  2. save_position_history=True
  3. restore_best_position=True
  4. prune_history=True
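
A minimal sketch of patience-based early stopping over a recorded loss sequence, which is either the training loss or the test loss if a test_model is supplied. It is written in plain Python for readability, whereas the PR runs its loops in JAX. `finalize_history` shows one plausible reading of what `restore_best_position=True` and `prune_history=True` amount to. Both function names and the `atol` tolerance are hypothetical.

```python
def early_stopping_index(losses, patience: int, atol: float = 0.0):
    """Patience-based early stopping over a recorded loss sequence.

    Returns (stop_iteration, best_iteration). Training stops at the first
    iteration where the loss has not improved on the best value by more than
    `atol` for `patience` consecutive iterations; if that never happens, the
    last iteration is returned as the stopping point.
    """
    best, best_i, wait = float("inf"), 0, 0
    for i, loss in enumerate(losses):
        if loss < best - atol:
            best, best_i, wait = loss, i, 0  # improvement: reset patience counter
        else:
            wait += 1
            if wait >= patience:
                return i, best_i
    return len(losses) - 1, best_i


def finalize_history(position_history, stop: int, best: int):
    """Hypothetical helper: restore the best position and prune the history after `stop`."""
    return position_history[best], position_history[: stop + 1]
```
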
GianmarcoCallegher commented 4 months ago

Thank you very much, @jobrachem.

Some general remarks:

What do you think? ❤️

jobrachem commented 4 months ago

Thank you very much for the speedy review @GianmarcoCallegher ! Your suggestions make sense to me, I will implement them.

GianmarcoCallegher commented 4 months ago

Concerning the default values, I think they are totally fine. Maybe I would not have a default for the number of iterations, so that the user has to specify it.

jobrachem commented 3 months ago

Over lunch, @wiep mentioned the practice of rescaling the likelihood when using minibatching, to ensure that the prior does not get more weight than warranted. I think this should be implemented in our function here, too. It can be used for minibatching in general and also to correct for a potentially smaller sample size in a validation model.
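
A minimal sketch of the rescaling. For a batch $B$ of size $b$ drawn from $n$ observations, the log joint is estimated as $\log p(\theta) + \frac{n}{b}\sum_{i \in B} \log p(y_i \mid \theta)$. Averaged over a partition of the data into batches, this recovers the full log joint. For a validation model, the same factor with `n_obs` set to the training sample size could play the correcting role mentioned above. `scaled_log_joint` is a hypothetical name.

```python
import jax.numpy as jnp


def scaled_log_joint(log_prior, batch_log_lik, n_obs: int):
    """Minibatch estimate of the log joint.

    batch_log_lik holds the per-observation log-likelihood terms of the current
    batch; their sum is rescaled by n_obs / batch_size so the prior keeps the
    same relative weight as with the full data.
    """
    batch_size = batch_log_lik.shape[0]
    return log_prior + (n_obs / batch_size) * jnp.sum(batch_log_lik)
```
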

GianmarcoCallegher commented 3 months ago

Of course, only if the objective is the log joint 😄

jobrachem commented 3 months ago

@GianmarcoCallegher I implemented my changes. If you are satisfied, please squash & merge 😊

GianmarcoCallegher commented 3 months ago

Everything LGTM, @jobrachem. I just left one final comment (it is probably not an error, I am just confused).

jobrachem commented 3 months ago

Thanks @GianmarcoCallegher ! See my reply above :)