ColmTalbot / wcosmo

Backend agnostic gwcosmology tools
https://wcosmo.readthedocs.io/en/latest/
MIT License
2 stars 1 forks source link

JittedLikelihood cosmo implementation failed #20

Open HuiTong5 opened 1 week ago

HuiTong5 commented 1 week ago

I got an error

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function generic_bilby_likelihood_function at /usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py:8 for jit. This concrete value was not available in Python because it depends on the value of the argument parameters['H0'].

when using the cosmo implementation within JittedLikelihood, which can be reproduced with a Colab notebook that simply follows the gwpopulation spectral siren example.

ColmTalbot commented 6 days ago

I'm able to reproduce this, though I'm a little confused about why it happens. Astropy is clearly adding units somewhere, which is not (currently) JAX-traceable, so I may need to do a little work to hide the units. Here are more details:

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
<timed eval> in <module>

[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py](https://localhost:8080/#) in log_likelihood_ratio(self)
     93     def log_likelihood_ratio(self):
     94         return float(
---> 95             np.nan_to_num(self.likelihood_func(self.parameters, **self.kwargs))
     96         )

    [... skipping hidden 11 frame]

13 frames
[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py](https://localhost:8080/#) in generic_bilby_likelihood_function(likelihood, parameters, use_ratio)
     22     likelihood.parameters.update(parameters)
     23     if use_ratio:
---> 24         return likelihood.log_likelihood_ratio()
     25     else:
     26         return likelihood.log_likelihood()

[/usr/local/lib/python3.10/dist-packages/gwpopulation/hyperpe.py](https://localhost:8080/#) in log_likelihood_ratio(self)
    186 
    187     def log_likelihood_ratio(self):
--> 188         ln_l, variance = self.ln_likelihood_and_variance()
    189         ln_l = xp.nan_to_num(ln_l, nan=-xp.inf)
    190         ln_l -= xp.nan_to_num(xp.inf * (self.maximum_uncertainty < variance), nan=0)

[/usr/local/lib/python3.10/dist-packages/gwpopulation/hyperpe.py](https://localhost:8080/#) in ln_likelihood_and_variance(self)
    176         self.parameters, added_keys = self.conversion_function(self.parameters)
    177         self.hyper_prior.parameters.update(self.parameters)
--> 178         ln_bayes_factors, variances = self._compute_per_event_ln_bayes_factors()
    179         ln_l = xp.sum(ln_bayes_factors)
    180         variance = xp.sum(variances)

[/usr/local/lib/python3.10/dist-packages/gwpopulation/hyperpe.py](https://localhost:8080/#) in _compute_per_event_ln_bayes_factors(self, return_uncertainty)
    203 
    204     def _compute_per_event_ln_bayes_factors(self, return_uncertainty=True):
--> 205         weights = self.hyper_prior.prob(self.data) / self.sampling_prior
    206         expectation = xp.mean(weights, axis=-1)
    207         if return_uncertainty:

[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/cosmo_models.py](https://localhost:8080/#) in prob(self, data, **kwargs)
    160         """
    161 
--> 162         data, jacobian = self.detector_frame_to_source_frame(data)
    163         probability = super().prob(data, **kwargs)
    164         probability /= jacobian

[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/cosmo_models.py](https://localhost:8080/#) in detector_frame_to_source_frame(self, data, **parameters)
     96         samples = dict()
     97         if "luminosity_distance" in data.keys():
---> 98             cosmo = self.cosmology(self.parameters)
     99             samples["redshift"] = z_at_value(
    100                 cosmo.luminosity_distance,

[/usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/cosmo_models.py](https://localhost:8080/#) in cosmology(self, parameters)
     70             return self._cosmo
     71         else:
---> 72             return self._cosmo(**self.cosmology_variables(parameters))
     73 
     74     def detector_frame_to_source_frame(self, data, **parameters):

/usr/local/lib/python3.10/dist-packages/wcosmo/astropy.py in __init__(self, H0, Om0, Tcmb0, Neff, m_nu, Ob0, w0, name, meta)

[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_core.py](https://localhost:8080/#) in __set__(self, cosmology, value)
    196 
    197         # Validate value, generally setting units if present
--> 198         value = self.validate(cosmology, copy.deepcopy(value))
    199 
    200         # Make the value read-only, if ndarray-like

[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_core.py](https://localhost:8080/#) in validate(self, cosmology, value)
    239             (yes, that parameter order).
    240         """
--> 241         return self._fvalidate(cosmology, self, value)
    242 
    243     @staticmethod

[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_converter.py](https://localhost:8080/#) in _validate_to_scalar(cosmology, param, value)
     82 def _validate_to_scalar(cosmology, param, value):
     83     """"""
---> 84     value = _validate_with_unit(cosmology, param, value)
     85     if not value.isscalar:
     86         raise ValueError(f"{param.name} is a non-scalar quantity")

[/usr/local/lib/python3.10/dist-packages/astropy/cosmology/parameter/_converter.py](https://localhost:8080/#) in _validate_with_unit(cosmology, param, value)
     68     if param.unit is not None:
     69         with u.add_enabled_equivalencies(param.equivalencies):
---> 70             value = u.Quantity(value, param.unit)
     71     return value
     72 

[/usr/local/lib/python3.10/dist-packages/astropy/units/quantity.py](https://localhost:8080/#) in __new__(cls, value, unit, dtype, copy, order, subok, ndmin)
    542                     copy = COPY_IF_NEEDED  # copy will be made in conversion at end
    543 
--> 544         value = np.array(
    545             value, dtype=dtype, copy=copy, order=order, subok=True, ndmin=ndmin
    546         )

[/usr/local/lib/python3.10/dist-packages/jax/_src/core.py](https://localhost:8080/#) in __array__(self, *args, **kw)
    682 
    683   def __array__(self, *args, **kw):
--> 684     raise TracerArrayConversionError(self)
    685 
    686   def __dlpack__(self, *args, **kw):

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function generic_bilby_likelihood_function at /usr/local/lib/python3.10/dist-packages/gwpopulation/experimental/jax.py:8 for jit. This concrete value was not available in Python because it depends on the value of the argument parameters['H0'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
ColmTalbot commented 6 days ago

I transferred this issue as this seems like the more relevant repo.