Open HuiTong5 opened 1 week ago
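I'm able to reproduce this, though I'm a little confused about why it happens. Astropy is clearly adding units somewhere, and astropy units are not (currently) JAX-traceable, so I may need to do some work to hide the units. The traceback below shows where this happens. When `wcosmo.astropy` constructs the cosmology, astropy's parameter validation wraps the traced `H0` in `u.Quantity`. That calls `np.array` on the JAX tracer, which raises `TracerArrayConversionError`. Here are the full details: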
---------------------------------------------------------------------------
TracerArrayConversionError Traceback (most recent call last)
<timed eval> in <module>
[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py](https://localhost:8080/#) in log_likelihood_ratio(self)
93 def log_likelihood_ratio(self):
94 return float(
---> 95 np.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs))
96 )
[... skipping hidden 11 frame]
13 frames
[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py](https://localhost:8080/#) in generic_bilby_likelihood_function(likelihood, parameters, use_ratio)
22 likelihood.parameters.update(parameters)
23 if use_ratio:
---> 24 return likelihood.log_likelihood_ratio()
25 else:
26 return likelihood.log_likelihood()
[/usr/local/lib/python3.10/dist-packages/gwpopulation/hyperpe.py](https://localhost:8080/#) in log_likelihood_ratio(self)
186
187 def log_likelihood_ratio(self):
--> 188 ln_l, variance = self.ln_likelihood_and_variance()
189 ln_l = xp.nan_to_num(ln_l, nan=-xp.inf)
190 ln_l -= xp.nan_to_num(xp.inf * (self.maximum_uncertainty < variance), nan=0)
[/usr/local/lib/python3.10/dist-packages/gwpopulation/hyperpe.py](https://localhost:8080/#) in ln_likelihood_and_variance(self)
176 self.parameters, added_keys = self.conversion_function(self.parameters)
177 self.hyper_prior.parameters.update(self.parameters)
--> 178 ln_bayes_factors, variances = self._compute_per_event_ln_bayes_factors()
179 ln_l = xp.sum(ln_bayes_factors)
180 variance = xp.sum(variances)
[/usr/local/lib/python3.10/dist-packages/gwpopulation/hyperpe.py](https://localhost:8080/#) in _compute_per_event_ln_bayes_factors(self, return_uncertainty)
203
204 def _compute_per_event_ln_bayes_factors(self, return_uncertainty=True):
--> 205 weights = self.hyper_prior.prob(self.data) / self.sampling_prior
206 expectation = xp.mean(weights, axis=-1)
207 if return_uncertainty:
[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/cosmo_models.py](https://localhost:8080/#) in prob(self, data, **kwargs)
160 """
161
--> 162 data, jacobian = self.detector_frame_to_source_frame(data)
163 probability = super().prob(data, **kwargs)
164 probability /= jacobian
[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/cosmo_models.py](https://localhost:8080/#) in detector_frame_to_source_frame(self, data, **parameters)
96 samples = dict()
97 if "luminosity_distance" in data.keys():
---> 98 cosmo = self.cosmology(self.parameters)
99 samples["redshift"] = z_at_value(
100 cosmo.luminosity_distance,
[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/cosmo_models.py](https://localhost:8080/#) in cosmology(self, parameters)
70 return self._cosmo
71 else:
---> 72 return self._cosmo(**self.cosmology_variables(parameters))
73
74 def detector_frame_to_source_frame(self, data, **parameters):
/usr/local/lib/python3.10/dist-packages/wcosmo/astropy.py in __init__(self, H0, Om0, Tcmb0, Neff, m_nu, Ob0, w0, name, meta)
[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_core.py](https://localhost:8080/#) in __set__(self, cosmology, value)
196
197 # Validate value, generally setting units if present
--> 198 value = self.validate(cosmology, copy.deepcopy(value))
199
200 # Make the value read-only, if ndarray-like
[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_core.py](https://localhost:8080/#) in validate(self, cosmology, value)
239 (yes, that parameter order).
240 """
--> 241 return self._fvalidate(cosmology, self, value)
242
243 @staticmethod
[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_converter.py](https://localhost:8080/#) in _validate_to_scalar(cosmology, param, value)
82 def _validate_to_scalar(cosmology, param, value):
83 """"""
---> 84 value = _validate_with_unit(cosmology, param, value)
85 if not value.isscalar:
86 raise ValueError(f"{param.name} is a non-scalar quantity")
[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_converter.py](https://localhost:8080/#) in _validate_with_unit(cosmology, param, value)
68 if param.unit is not None:
69 with u.add_enabled_equivalencies(param.equivalencies):
---> 70 value = u.Quantity(value, param.unit)
71 return value
72
[/usr/local/lib/python3.10/dist-packages/astropy/units/quantity.py](https://localhost:8080/#) in __new__(cls, value, unit, dtype, copy, order, subok, ndmin)
542 copy = COPY_IF_NEEDED # copy will be made in conversion at end
543
--> 544 value = np.array(
545 value, dtype=dtype, copy=copy, order=order, subok=True, ndmin=ndmin
546 )
[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in __array__(self, *args, **kw)
682
683 def __array__(self, *args, **kw):
--> 684 raise TracerArrayConversionError(self)
685
686 def __dlpack__(self, *args, **kw):
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function generic_bilby_likelihood_function at /usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py:8 for jit. This concrete value was not available in Python because it depends on the value of the argument parameters['H0'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
I transferred this issue here because this seems like the more relevant repository.
I got an error when using the cosmology implementation within `JittedLikelihood`. It can be reproduced with a Colab notebook that simply follows the gwpopulation spectral siren example.