furcelay / gigalens

Gradient Informed, GPU Accelerated Lens modelling (GIGALens) -- a package for fast Bayesian inference on strong gravitational lenses.
https://giga-lens.github.io/gigalens
MIT License

Better cache treatment for precomputed models (MassSeries) #5

Closed: furcelay closed this issue 1 month ago

furcelay commented 6 months ago

MassSeries currently stores precomputed values for a single (x, y) grid only. When both pixels and point sources are used as constraints, only the pixel grid is precomputed.

In TensorFlow we could use a dictionary as a cache with keys like hash((x.ref(), y.ref())), but a JAX array cannot be hashed. Maybe keep a list of (x, y) grids and test for equality?
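A minimal sketch of the TensorFlow dictionary idea (illustrative only, not the gigalens API; precompute_deriv is the same stand-in name used in the snippet below, and tf.Tensor.ref() keys match by tensor identity, so the cache only hits when the very same grid tensors are passed again):

import tensorflow as tf

# Illustrative sketch: a dict cache keyed by tf.Tensor.ref(), which is
# hashable while tensors themselves are not. The key is identity-based,
# so it only hits when the same (x, y) tensor objects are reused.
_deriv_cache = {}

def get_cached_tf(x, y, **kwargs):
    key = (x.ref(), y.ref())
    if key not in _deriv_cache:
        # precompute_deriv is a stand-in for the MassSeries precomputation
        _deriv_cache[key] = precompute_deriv(x, y, **kwargs)
    return _deriv_cache[key]

For JAX, the value-based list approach could then look like: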

cache_xy = [(x0, y0), (x1, y1), ...]                  # cached coordinate grids
cache_deriv = [(f_x_0, f_y_0), (f_x_1, f_y_1), ...]   # cached deflection derivatives

def get_cached(x, y, **kwargs):
    # look for an already precomputed grid by element-wise value equality
    for (x_i, y_i), (f_x_i, f_y_i) in zip(cache_xy, cache_deriv):
        if (x == x_i).all() and (y == y_i).all():
            return f_x_i, f_y_i

    # not cached: compute and store
    f_x, f_y = precompute_deriv(x, y, **kwargs)
    cache_xy.append((x, y))
    cache_deriv.append((f_x, f_y))
    return f_x, f_y

Need to test whether this is efficient under jit compilation and does not cause retracing.
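For reference, a minimal sketch of the concern (the grids and the jitted function here are purely illustrative, not gigalens code): a Python if over an equality of traced arrays cannot run inside jax.jit, so the lookup would have to stay outside any compiled function.

import jax
import jax.numpy as jnp

# Illustrative sketch: inside jit, x is an abstract tracer, so the
# Python-level equality check used by the list cache cannot be
# evaluated and JAX raises a concretization error.
x_grid = jnp.linspace(-1.0, 1.0, 100)
y_grid = jnp.linspace(-1.0, 1.0, 100)

@jax.jit
def lookup_inside_jit(x, y):
    # data-dependent Python branch: fails when x is a tracer
    if (x == x_grid).all():
        return y
    return -y

try:
    lookup_inside_jit(x_grid, y_grid)
except jax.errors.ConcretizationTypeError as err:
    print("cache lookup is not jit-compatible:", type(err).__name__)

Keeping the lookup eager and only jitting the code that consumes the cached derivatives avoids the error, but a grid with a new shape still triggers recompilation.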

furcelay commented 1 month ago

The cache has been deprecated since it works poorly with jit compilation. It is better to create a new simulator for each grid.