google / bayesnf

Bayesian Neural Field models for prediction in large-scale spatiotemporal datasets
https://google.github.io/bayesnf/
Apache License 2.0
88 stars 8 forks source link

Integer division in MAP ensemble_size causes a crash when ensemble_size < device_count #28

Open fsaad opened 8 months ago

fsaad commented 8 months ago

https://github.com/google/bayesnf/blob/fb59400ab86aa16a548f6df566bc0d5ba6e19eb5/src/bayesnf/inference.py#L445

If `ensemble_size < jax.device_count()`, the integer division of `ensemble_size` by the device count evaluates to 0, so 0 particles are fitted.

In terms of the API, `.fit` silently fails, but `.predict` raises an error because it performs a min/max reduction over empty arrays:

ValueError: zero-size array to reduction operation min which has no identity

/bayesnf/spatiotemporal.py in predict(self, table, quantiles)
    259   def predict(self, table, quantiles=(0.5,)):
    260     test_data = self.data_handler.get_test(table)
--> 261     return inference.predict_bnf(
    262         test_data,
    263         self.observation_model,

/bayesnf/inference.py in predict_bnf(features, observation_model, params, model_args, quantiles, ensemble_dims, approximate_quantiles)
    468     (means, scales) = forecast_params
    469     forecast_means = means
--> 470     forecast_quantiles = _get_percentile_normal(
    471         forecast_means,
    472         scales,

/bayesnf/inference.py in _get_percentile_normal(means, scales, quantiles, axis, approximate)
     82   for q in quantiles:
     83     forecast_quantiles.append(
---> 84         quantile_fn(means, scales[..., jnp.newaxis], q, axis)
     85     )
     86   return forecast_quantiles

/bayesnf/inference.py in _normal_quantile_via_root(means, scales, q, axis)
     31   res = tfp.math.find_root_chandrupatla(
     32       lambda x: n.cdf(x).mean(axis) - q,
---> 33       low=jnp.amin(means) - 5 * jnp.amax(scales),
     34       high=jnp.amax(means) + 5 * jnp.amax(scales),
     35       value_tolerance=1e-5,

jax/_src/numpy/reductions.py in min(a, axis, out, keepdims, initial, where)
    276         keepdims: bool = False, initial: ArrayLike | None = None,
    277         where: ArrayLike | None = None) -> Array:
--> 278   return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
    279                      keepdims=keepdims, initial=initial, where=where)
    280 

jax/_src/numpy/reductions.py in _reduce_min(a, axis, out, keepdims, initial, where)
    268                 keepdims: bool = False, initial: ArrayLike | None = None,
    269                 where: ArrayLike | None = None) -> Array:
--> 270   return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
    271                     axis=axis, out=out, keepdims=keepdims,
    272                     initial=initial, where_=where, parallel_reduce=lax.pmin)

jax/_src/numpy/reductions.py in _reduction(a, name, np_fun, op, init_val, has_identity, preproc, bool_op, upcast_f16_for_computation, axis, dtype, out, keepdims, initial, where_, parallel_reduce, promote_integers)
     99     shape = np.shape(a)
    100     if not _all(shape[d] >= 1 for d in pos_dims):
--> 101       raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
    102 
    103   result_dtype = dtype or dtypes.dtype(a)
fsaad commented 8 months ago

Related #19