flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

Group Lasso - Pytree compatibility #78

Open BalzaniEdoardo opened 10 months ago

BalzaniEdoardo commented 10 months ago

As of now, the group Lasso regularizer is compatible with arrays only, not with pytrees. The fundamental difference is that when we pass a model matrix in array format, parameters are grouped based on column indices; when we pass pytree parameters, on the other hand, the groups are given by the dictionary structure itself (i.e. params["group_1"], ..., params["group_n"] are the different groups).
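For concreteness, here is a minimal sketch of the two representations (the shapes and group names are illustrative, not taken from nemos):

# illustrative only: two ways of expressing the same grouping
import jax.numpy as jnp

# array representation: groups are encoded by a binary mask over columns
coef = jnp.ones((2, 5))              # (n_neurons, n_features)
mask = jnp.array([[1, 1, 0, 0, 0],   # group 1: columns 0-1
                  [0, 0, 1, 1, 1]])  # group 2: columns 2-4

# pytree representation: the dict keys themselves define the groups
coef_tree = {"group_1": coef[:, :2], "group_2": coef[:, 2:]}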

We need to discuss if and how to maintain compatibility with both pytrees and arrays for group Lasso. Below is an implementation of the group Lasso operator that works with pytrees only:

This is a proximal operator equivalent to the original implementation, but operating on pytree parameters. It assumes a tree representation of the regularizer strength, which could be more flexible for cases in which we want a variable-specific regularizer strength.

# should be added to src/nemos/proximal_operator.py
from typing import Tuple

import jax
import jax.numpy as jnp

from .base_class import DESIGN_INPUT_TYPE  # Union[jnp.ndarray, FeaturePytree]


def prox_group_lasso_pytree(
    params: Tuple[DESIGN_INPUT_TYPE, jnp.ndarray], l2reg: DESIGN_INPUT_TYPE, scaling=1.0
):
    r"""Proximal operator for the l2 norm, i.e., block soft-thresholding operator.

     Parameters
    ----------
    params :
        The input. `params[0]` are the weights (a tree of JAX arrays or FeaturePytree).
        `params[1]` are the intercepts (a JAX array).
    l2reg :
        The regularization strength, which is a pytree with
        the same structure as `params[0]`.
    scaling :
        A scaling factor applied to the regularization term. Defaults to 1.0.

    Returns
    -------
    :
        The rescaled weights.

    Notes
    -----
    This function implements the proximal operator for a group-Lasso penalization which
    can be derived in analytical form.
    The proximal operator equation is,

    $$
    \text{prox}(\hat{\beta}) = \text{argmin}_{\beta} \left[ \lambda  \sum\_{g=1}^G \Vert \beta_g \Vert_2 +
     \frac{1}{2} \Vert \hat{\beta} - \beta \Vert_2^2
    \right],
    $$
    where $G$ is the number of groups, and $\beta_g$ is the parameter vector
    associated with the $g$-th group.
    The analytical solution[^1] for the $g$-th group of coefficients is,

    $$
    \text{prox}(\hat{\beta})\_g = \max \left(1 - \frac{\lambda \sqrt{p\_g}}{\Vert \hat{\beta}\_g \Vert_2},
     0\right) \cdot \hat{\beta}\_g,
    $$
    where $p_g$ is the dimensionality of $\beta\_g$, and $\hat{\beta}$ is typically the gradient step
    of the un-regularized optimization objective. It is easy to see that the group-Lasso
    proximal operator acts as a shrinkage factor on the un-penalized update, with the half-rectification
    non-linearity effectively setting to zero any group of coefficients satisfying,
    $$
    \Vert \hat{\beta}\_g \Vert_2 \le \lambda \sqrt{p\_g}.
    $$

    [^1]:
        Yuan, Ming, and Yi Lin. "Model selection and estimation in regression with grouped variables."
        Journal of the Royal Statistical Society Series B: Statistical Methodology 68.1 (2006): 49-67.
    """
    # assume the last axis is the feature axis
    l2_norm = jax.tree_map(
        lambda xx: jnp.linalg.norm(xx, axis=-1, keepdims=True) / jnp.sqrt(xx.shape[-1]),
        params[0],
    )
    # shrinkage factor 1 - scaling * lambda * sqrt(p_g) / ||beta_g||_2, one per group
    factor = jax.tree_map(lambda xx, yy: 1 - xx * scaling / yy, l2reg, l2_norm)
    # half-rectification zeroes the factor for groups below the threshold
    factor = jax.tree_map(jax.nn.relu, factor)
    # rescale the weights; the intercept is returned un-penalized
    return jax.tree_map(lambda xx, yy: xx * yy, factor, params[0]), params[1]

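For concreteness, a minimal usage sketch of the operator above (the group names, shapes, and strengths are made up for illustration; plain dicts are used since any pytree with matching structure works):

# hypothetical usage of prox_group_lasso_pytree (illustrative values only)
import jax.numpy as jnp

weights = {"stim": jnp.ones((3, 4)), "history": jnp.ones((3, 2))}  # one leaf per group
intercepts = jnp.zeros(3)
# one regularizer strength per group, same tree structure as the weights
l2reg = {"stim": jnp.array([0.5]), "history": jnp.array([0.1])}

new_weights, new_intercepts = prox_group_lasso_pytree((weights, intercepts), l2reg)
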
A test that checks the equivalence between the pytree-based and the array-based implementation is the following:

# this should be added to tests/test_proximal_operator.py
import jax
import jax.numpy as jnp

from nemos.proximal_operator import prox_group_lasso, prox_group_lasso_pytree
from nemos.pytrees import FeaturePytree


def test_compare_group_lasso(example_data_prox_operator):
    """Compare the group lasso prox operators."""
    params, regularizer_strength, mask, scaling = example_data_prox_operator
    # create a pytree version of params
    params_tree = FeaturePytree(
        **{f"{k}": params[0][:, jnp.array(msk, dtype=bool)] for k, msk in enumerate(mask)}
    )
    # create a regularizer tree with the same struct as params_tree
    treedef = jax.tree_util.tree_structure(params_tree)
    # make sure the leaves are arrays (otherwise FeaturePytree cannot be instantiated)
    alpha_tree = jax.tree_util.tree_unflatten(
        treedef, [jnp.atleast_1d(regularizer_strength)] * treedef.num_leaves
    )
    # compute updates using both functions
    updated_params = prox_group_lasso(params, regularizer_strength, mask, scaling)
    updated_params_tree = prox_group_lasso_pytree((params_tree, params[1]), alpha_tree, scaling)
    # check agreement
    check_updates = [
        jnp.all(
            updated_params[0][:, jnp.array(msk, dtype=bool)] == updated_params_tree[0][f"{k}"]
        )
        for k, msk in enumerate(mask)
    ]
    assert all(check_updates)
    assert jnp.all(updated_params_tree[1] == updated_params[1])
wulfdewolf commented 3 months ago

If I can add to this discussion as a user: it would be very convenient for me to be able to apply GroupLasso to my pytree and to specify the groups (through the mask) using feature labels, instead of having to generate a mask by hand.
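A sketch of what that could look like (build_mask_from_labels and the label values are hypothetical, not an existing nemos API): the mask is derived from per-column feature labels instead of being written out by hand.

# hypothetical helper, not part of nemos: derive a group-Lasso mask from labels
import jax.numpy as jnp

def build_mask_from_labels(feature_labels, groups):
    """Return a (n_groups, n_features) binary mask from per-column labels."""
    return jnp.array(
        [[1.0 if lab in group else 0.0 for lab in feature_labels] for group in groups]
    )

labels = ["stim", "stim", "history", "history", "history"]
mask = build_mask_from_labels(labels, groups=[("stim",), ("history",)])
# mask -> [[1., 1., 0., 0., 0.],
#          [0., 0., 1., 1., 1.]]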