py-econometrics / pyfixest

Fast High-Dimensional Fixed Effects Regression in Python following fixest-syntax
https://py-econometrics.github.io/pyfixest/pyfixest.html
MIT License

Implement Core Algorithms in JAX #380

Open issue, opened by s3alfisc 3 months ago

s3alfisc commented 3 months ago

The other day @janosg suggested that implementing some of the core algorithms in JAX might provide exceptional performance improvements.

pyfixest's alternating projections algorithm is implemented here.

A vectorized numpy version of the algorithm is implemented in the pyhdfe library.

A very simple implementation of the MAP (method of alternating projections) algorithm is sketched in R in the lfe documentation (at the bottom of page 5):

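demean <- function(v, fl) {
  # v: numeric vector to demean; fl: list of factors, one per fixed effect
  Pv <- v
  oldv <- v - 1  # guarantees that the first convergence check passes
  # repeat full sweeps until one sweep changes Pv by less than 1e-7 (Euclidean norm)
  while (sqrt(sum((Pv - oldv)^2)) >= 1e-7) {
    oldv <- Pv
    # one sweep: subtract the group means of each factor in turn
    for (f in fl) Pv <- Pv - ave(Pv, f)
  }
  Pv
}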
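The R loop above translates almost line for line to NumPy. This is a minimal sketch, assuming each fixed effect is given as an integer-coded array of group ids 0, 1, ..., G-1; np.bincount plays the role of R's ave(). The name demean_map and the max_iter cap are illustrative additions, not part of pyfixest, pyhdfe, or the lfe code.

```python
import numpy as np


def demean_map(v, fl, tol=1e-7, max_iter=10_000):
    """Python sketch of the lfe MAP loop: sweep over the factors, subtracting
    group means, until one full sweep changes the vector by less than tol.

    v:  1-D array to demean
    fl: list of 1-D integer arrays of group codes (0, 1, ..., G-1), one per factor
    max_iter: hypothetical safety cap on the number of sweeps (not in the R version)
    """
    pv = np.asarray(v, dtype=float).copy()
    for _ in range(max_iter):
        old = pv.copy()
        for f in fl:
            # group means broadcast back to observations: the analogue of ave(Pv, f)
            sums = np.bincount(f, weights=pv)
            counts = np.maximum(np.bincount(f), 1)
            pv = pv - (sums / counts)[f]
        # same stopping rule as the R code: Euclidean norm of the change in one sweep
        if np.sqrt(np.sum((pv - old) ** 2)) < tol:
            break
    return pv
```

With a single factor one sweep is already exact; with several factors the loop converges to the residuals of a least-squares regression on all sets of group dummies.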

Also tagging @jeffgortmaker, as I recall you suggested using JAX to solve the MAP fixed-point problem at some point =)

janosg commented 3 months ago

No guarantees, but it could be worth a try. The important thing would be to use lax control flow operators (e.g. lax.while_loop, lax.scan) instead of Python loops and to vectorize as much as possible.
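To make "vectorize as much as possible" concrete: the group means of every column of a matrix can be computed with one scatter-add instead of a Python loop over groups. A minimal NumPy sketch, assuming integer group codes below n_groups; np.add.at does roughly what jax.ops.segment_sum does in JAX, and both function names are hypothetical.

```python
import numpy as np


def group_means_loop(x, g):
    """Reference version: a Python loop over groups."""
    out = np.empty(x.shape, dtype=float)
    for k in np.unique(g):
        mask = g == k
        out[mask] = x[mask].mean(axis=0)
    return out


def group_means_vectorized(x, g, n_groups):
    """Vectorized version: a single scatter-add over all rows and columns."""
    sums = np.zeros((n_groups, x.shape[1]))
    np.add.at(sums, g, x)  # sums[g[i], :] += x[i, :], accumulating repeated indices
    counts = np.bincount(g, minlength=n_groups)
    means = sums / np.maximum(counts, 1)[:, None]
    return means[g]  # broadcast each group's mean back to its rows
```

Both return the same matrix; only the second avoids interpreter-level loops, which is the pattern that carries over to jit-compiled JAX code.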

s3alfisc commented 3 months ago

I might give it a try over the many holidays coming up here in NRW :D

juanitorduz commented 1 month ago

Which algorithm would be a good candidate to start with? I might wanna try it too :)

jeffgortmaker commented 1 month ago

Probably just simple iterative de-meaning. Here's something close that I threw together for my own research:

def jax_residualize(matrix, ids, iterations=100):
    """Iteratively residualize a matrix within groups. With more than one set of groups, simply de-mean up to a fixed
    number of iterations to allow for reverse mode differentiation (neither custom roots nor returning the same thing
    after a tolerance was reached seemed to be working).
    """
    import jax
    import jax.numpy as jnp

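    def demean(values, groups):
        """Demean within groups."""
        # groups must hold integer codes in [0, n_obs): groups.size serves as the static number of segments
        group_sums = jax.ops.segment_sum(values, groups, groups.size)
        group_counts = jax.ops.segment_sum(jnp.ones_like(values), groups, groups.size)
        # unused group codes have a count of 0; divide by 1 there to avoid 0/0
        group_means = group_sums / jnp.where(group_counts == 0, 1, group_counts)
        return values - group_means.at[groups].get()

    def demean_all(values):
        """Demean within all sets of groups."""
        # one sweep: scan over the fixed-effect columns (rows of ids.T)
        return jax.lax.scan(lambda v, g: (demean(v, g), None), values, ids.T)[0]

    def residualize_column(column):
        """Residualize a single column within groups."""
        # a single set of groups needs only one demeaning pass
        if len(ids.shape) == 1:
            return demean(column, ids)
        if ids.shape[1] == 1:
            return demean(column, ids.at[:, 0].get())
        # several sets of groups: a fixed number of sweeps over all of them
        return jax.lax.scan(lambda x, _: (demean_all(x), None), column, None, length=iterations)[0]

    # residualize each column independently, vectorized with vmap
    return jax.vmap(residualize_column)(matrix.T).T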

You probably want to replace the fixed-length lax.scan with a lax.while_loop that terminates either after hitting a maximum number of iterations or after reaching a convergence tolerance. You can probably also improve on using groups.size (the number of observations) as the number of segments in place of a static number of groups, etc.
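A sketch of the suggested termination logic, written in NumPy so it runs without JAX. The while_loop helper is a hypothetical pure-Python stand-in with the (cond_fun, body_fun, init_val) contract that the JAX documentation gives for jax.lax.while_loop; in real JAX code the carried state must keep fixed shapes and dtypes, and the demeaning step would use jax.ops.segment_sum. The names residualize_while, n_groups, max_iter, and tol are illustrative assumptions.

```python
import numpy as np


def while_loop(cond_fun, body_fun, init_val):
    """Pure-Python stand-in with the semantics documented for jax.lax.while_loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def demean(values, groups, n_groups):
    """Subtract group means; n_groups plays the role of a static number of segments."""
    sums = np.bincount(groups, weights=values, minlength=n_groups)
    counts = np.bincount(groups, minlength=n_groups)
    return values - (sums / np.maximum(counts, 1))[groups]


def residualize_while(column, ids, n_groups, max_iter=1000, tol=1e-8):
    """Sweep over all fixed effects until the largest change in one sweep is at most
    tol, or until max_iter sweeps have been done. Returns (residuals, sweeps used).

    ids: (n_obs, n_fe) integer array of group codes, each below n_groups.
    """
    def sweep(v):
        # in JAX this inner loop could remain a lax.scan over ids.T
        for k in range(ids.shape[1]):
            v = demean(v, ids[:, k], n_groups)
        return v

    def cond_fun(state):
        _, delta, it = state
        # in JAX, combine traced booleans with & (or jnp.logical_and), not `and`
        return (it < max_iter) and (delta > tol)

    def body_fun(state):
        v, _, it = state
        new_v = sweep(v)
        return new_v, np.max(np.abs(new_v - v)), it + 1

    # carried state: (current values, change in the last sweep, sweep count)
    init = (np.asarray(column, dtype=float), np.inf, 0)
    v, _, it = while_loop(cond_fun, body_fun, init)
    return v, it
```

Returning the sweep count makes it easy to tell whether the tolerance or the iteration cap ended the loop.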

(I wrote the above code to always do a fixed number of iterations for multiple dimensions of fixed effects because in my own work it is part of a function that I need to reverse-mode differentiate. I couldn't manage to wrap lax.while_loop with a lax.custom_root in this setting; I'm not sure why. If you manage to get reverse-mode differentiation working, I'd love to see it!)
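One observation that may help with reverse mode, offered as a sketch rather than a tested JAX recipe: at convergence, alternating projections return M v, where M is the orthogonal projection that removes the span of all fixed-effect dummies. M is linear and symmetric, so the vector-Jacobian product of v -> M v is simply g -> M g; a backward pass could, in principle, run the same demeaning solver on the cotangent (for example via jax.custom_vjp). The NumPy code below only checks symmetry and idempotence numerically; dummies and project_out are hypothetical dense reference helpers, suitable for small examples only.

```python
import numpy as np


def dummies(ids, n_groups):
    """Dense one-hot design matrix for all fixed effects (small examples only)."""
    return np.hstack([np.eye(n_groups)[ids[:, k]] for k in range(ids.shape[1])])


def project_out(v, ids, n_groups):
    """Reference M v: residual of a least-squares regression of v on all FE dummies.
    This is the vector that alternating projections started at v converge to."""
    D = dummies(ids, n_groups)
    coef, *_ = np.linalg.lstsq(D, v, rcond=None)
    return v - D @ coef
```

Symmetry means <M v, g> = <v, M g> for any v and g, which is exactly the statement that M is its own transpose.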

juanitorduz commented 1 month ago

Wow! Thanks for the (fast!) input! I have run into this issue before (see https://github.com/pyro-ppl/numpyro/pull/1731), and empirically I have seen that forward differentiation is much slower for large data sets. I guess that is expected given what they explain in https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jvps-in-jax-code. I am not an expert though 🙈

jeffgortmaker commented 1 month ago

Yep! It makes sense that differentiating through each iteration of the fixed-point loop would be slower.

I've had success in other settings with the lax.custom_root approach, which is compatible with reverse-mode differentiation (it applies the implicit function theorem to the root implied by the fixed point), but I'm not sure why it wasn't working here (I was getting derivatives of zero). It could have just been a bug in my code. Also not an expert!
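For reference, the implicit-function-theorem step behind this kind of custom-root differentiation, written for a generic fixed point x* = T(x*, θ) (the notation is illustrative, not taken from the thread). Reverse mode only needs the second line: one linear solve for u, then one vector-Jacobian product with the partial derivative of T with respect to θ.

```latex
x^*(\theta) = T\bigl(x^*(\theta), \theta\bigr)
\;\Longrightarrow\;
\frac{\partial x^*}{\partial \theta}
  = \bigl(I - \partial_x T(x^*, \theta)\bigr)^{-1}\, \partial_\theta T(x^*, \theta)

\bar{\theta}^{\top} = u^{\top}\, \partial_\theta T(x^*, \theta),
\qquad \text{where} \quad
u^{\top}\bigl(I - \partial_x T(x^*, \theta)\bigr) = \bar{x}^{\top}
```

One hedged guess, not confirmed anywhere in the thread: if T is a plain MAP sweep, the data vector enters only as the starting value of the iteration rather than as an argument of T. In that case the partial derivative of T with respect to θ is zero (so the formula returns zero), and I minus the Jacobian of T is singular on the set of fixed points.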

s3alfisc commented 1 month ago

Thanks @jeffgortmaker for chiming in and even providing some code! =) I agree that the demeaning algo is the one we should try first, as it is easily vectorized and it is the most costly algo to run.