`MassSeries` currently stores the precomputed values for only a single (x, y) grid. When both pixels and point sources are used as constraints, only the pixel grid is precomputed.
In TensorFlow, a dictionary can serve as the cache, with keys `hash((x.ref(), y.ref()))`, but a JAX array cannot be hashed. Maybe use a list of (x, y) pairs and test for equality instead?
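```python
import jax.numpy as jnp

cache_xy = []     # [(x0, y0), (x1, y1), ...]
cache_deriv = []  # [(f_x_0, f_y_0), (f_x_1, f_y_1), ...]

def get_cached(x, y, **kwargs):
    # Linear scan; plain `==` would return an element-wise array, not a bool
    for (x_i, y_i), (f_x_i, f_y_i) in zip(cache_xy, cache_deriv):
        # Branching on this comparison only works for concrete arrays, not tracers
        if jnp.array_equal(x, x_i) and jnp.array_equal(y, y_i):
            return f_x_i, f_y_i
    # not cached, compute and save
    f_x, f_y = precompute_deriv(x, y, **kwargs)
    cache_xy.append((x, y))
    cache_deriv.append((f_x, f_y))
    return f_x, f_y
```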
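An alternative to the linear scan, which also sidesteps the hashing problem, is to build a dictionary key from the array *contents* rather than the array object. The sketch below is an illustration, not part of the current code: `xy_key`, `DerivCache` and `fake_deriv` are hypothetical names, and it assumes the grids are concrete (not traced) when the cache is queried. Unlike `Tensor.ref()`, which keys on object identity, this keys on values, so two equal grids created separately share one entry.

```python
import numpy as np
import jax.numpy as jnp


def xy_key(x, y):
    """Hashable key from the shape, dtype and raw bytes of two concrete arrays."""
    xa, ya = np.asarray(x), np.asarray(y)  # raises on tracers inside jax.jit
    return (xa.shape, xa.dtype.str, xa.tobytes(), ya.shape, ya.dtype.str, ya.tobytes())


class DerivCache:
    """Stores precomputed derivatives for any number of (x, y) grids."""

    def __init__(self, compute_fn):
        self._compute_fn = compute_fn  # e.g. the existing precompute_deriv
        self._store = {}

    def get(self, x, y, **kwargs):
        # kwargs are not part of the key: assumed fixed for the cache's lifetime
        key = xy_key(x, y)
        if key not in self._store:
            self._store[key] = self._compute_fn(x, y, **kwargs)
        return self._store[key]

    def __len__(self):
        return len(self._store)


# Usage with a hypothetical stand-in for precompute_deriv
def fake_deriv(x, y):
    return jnp.sin(x), jnp.cos(y)


cache = DerivCache(fake_deriv)
pix_x, pix_y = jnp.meshgrid(jnp.linspace(-1.0, 1.0, 4), jnp.linspace(-1.0, 1.0, 4))
ps_x, ps_y = jnp.array([0.3, -0.2]), jnp.array([0.1, 0.5])
cache.get(pix_x, pix_y)             # pixel grid: computed
cache.get(ps_x, ps_y)               # point sources: computed
cache.get(jnp.array(pix_x), pix_y)  # equal values, new array object: cache hit
print(len(cache))  # 2
```

Each lookup is a single dictionary access instead of a scan over all cached grids, but it still copies the grids to host memory, so like the equality scan it needs concrete arrays and cannot run inside `jax.jit`. Keys compare raw bytes, so `0.0` and `-0.0` map to different entries, unlike `jnp.array_equal`.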
Need to test whether this caching is efficient under jit compilation and does not cause retracing.
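Two quick checks for this, sketched under the assumption of a recent JAX release (`model`, `lookup` and `n_traces` are hypothetical names). First, Python code inside a jitted function runs only while JAX traces it, so a counter incremented there counts (re)traces. Second, inside `jax.jit` the array comparison used in `get_cached` yields a traced boolean, which a Python `if` cannot branch on.

```python
import jax
import jax.numpy as jnp

n_traces = 0


@jax.jit
def model(x, y):
    global n_traces
    n_traces += 1  # Python side effect: executes only when JAX (re)traces
    return x**2 + y**2


grid = jnp.linspace(-1.0, 1.0, 8)
model(grid, grid)
model(grid + 1.0, grid)    # same shapes/dtypes: the cached trace is reused
model(grid[:4], grid[:4])  # new shape: triggers a retrace
print(n_traces)  # 2


@jax.jit
def lookup(x, x_cached):
    # As in get_cached under jit: the comparison is traced, so `if` fails
    if jnp.array_equal(x, x_cached):
        return x_cached
    return x


try:
    lookup(grid, grid)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True
print(failed)  # True
```

The second check suggests that the cache lookup has to happen outside the jitted function, for example by precomputing the derivatives eagerly and passing them in as arguments, so that repeated calls with same-shaped grids reuse a single trace.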