Status: Open — opened by fsaad 8 months ago
https://github.com/google/bayesnf/blob/fb59400ab86aa16a548f6df566bc0d5ba6e19eb5/src/bayesnf/inference.py#L445
If `ensemble_size < jax.device_count()`, then 0 particles are fitted. In terms of the API, `.fit` fails silently, but `.predict` raises an error, because it performs a min/max reduction over empty arrays:
```
ValueError: zero-size array to reduction operation min which has no identity

/bayesnf/spatiotemporal.py in predict(self, table, quantiles)
    259 def predict(self, table, quantiles=(0.5,)):
    260   test_data = self.data_handler.get_test(table)
--> 261   return inference.predict_bnf(
    262       test_data,
    263       self.observation_model,

/bayesnf/inference.py in predict_bnf(features, observation_model, params, model_args, quantiles, ensemble_dims, approximate_quantiles)
    468   (means, scales) = forecast_params
    469   forecast_means = means
--> 470   forecast_quantiles = _get_percentile_normal(
    471       forecast_means,
    472       scales,

/bayesnf/inference.py in _get_percentile_normal(means, scales, quantiles, axis, approximate)
     82   for q in quantiles:
     83     forecast_quantiles.append(
---> 84         quantile_fn(means, scales[..., jnp.newaxis], q, axis)
     85     )
     86   return forecast_quantiles

/bayesnf/inference.py in _normal_quantile_via_root(means, scales, q, axis)
     31   res = tfp.math.find_root_chandrupatla(
     32       lambda x: n.cdf(x).mean(axis) - q,
---> 33       low=jnp.amin(means) - 5 * jnp.amax(scales),
     34       high=jnp.amax(means) + 5 * jnp.amax(scales),
     35       value_tolerance=1e-5,

jax/_src/numpy/reductions.py in min(a, axis, out, keepdims, initial, where)
    276         keepdims: bool = False, initial: ArrayLike | None = None,
    277         where: ArrayLike | None = None) -> Array:
--> 278   return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
    279                      keepdims=keepdims, initial=initial, where=where)
    280

jax/_src/numpy/reductions.py in _reduce_min(a, axis, out, keepdims, initial, where)
    268                 keepdims: bool = False, initial: ArrayLike | None = None,
    269                 where: ArrayLike | None = None) -> Array:
--> 270   return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
    271                     axis=axis, out=out, keepdims=keepdims,
    272                     initial=initial, where_=where, parallel_reduce=lax.pmin)

jax/_src/numpy/reductions.py in _reduction(a, name, np_fun, op, init_val, has_identity, preproc, bool_op, upcast_f16_for_computation, axis, dtype, out, keepdims, initial, where_, parallel_reduce, promote_integers)
     99   shape = np.shape(a)
    100   if not _all(shape[d] >= 1 for d in pos_dims):
--> 101     raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
    102
    103   result_dtype = dtype or dtypes.dtype(a)
```
Related #19
https://github.com/google/bayesnf/blob/fb59400ab86aa16a548f6df566bc0d5ba6e19eb5/src/bayesnf/inference.py#L445
If `ensemble_size < jax.device_count()`, then 0 particles are fitted. In terms of the API, `.fit` fails silently, but `.predict` raises an error, because it performs a min/max reduction over empty arrays.