dfm / tess-atlas

MIT License
9 stars 8 forks source link

inference trace saving error #187

Closed avivajpeyi closed 2 years ago

avivajpeyi commented 2 years ago
tic_entry.save_data(inference_data=inference_data)
summary(inference_data)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_156657/3890135057.py in <module>
----> 1 tic_entry.save_data(inference_data=inference_data)
      2 summary(inference_data)

/fred/oz200/avajpeyi/projects/tess-atlas/src/tess_atlas/data/tic_entry.py in save_data(self, inference_data)
    138         if inference_data is not None:
    139             self.inference_data = inference_data
--> 140             save_inference_data(inference_data, self.outdir)
    141         logger.info(f"Saved data in {self.outdir}")

/fred/oz200/avajpeyi/projects/tess-atlas/src/tess_atlas/data/inference_data_tools.py in save_inference_data(inference_data, outdir)
     30 def save_inference_data(inference_data, outdir: str):
     31     fname = get_idata_fname(outdir)
---> 32     az.to_netcdf(inference_data, filename=fname)
     33     save_samples(inference_data, outdir)
     34 

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/arviz/data/io_netcdf.py in to_netcdf(data, filename, group, coords, dims)
     50     """
     51     inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
---> 52     file_name = inference_data.to_netcdf(filename)
     53     return file_name

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    390                 if compress:
    391                     kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    393                 data.close()
    394                 mode = "a"

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1898         from ..backends.api import to_netcdf
   1899 
-> 1900         return to_netcdf(
   1901             self,
   1902             path,

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1075         # TODO: allow this work (setting up the file for writing array data)
   1076         # to be parallelized with dask
-> 1077         dump_to_store(
   1078             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1079         )

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1122         variables, attrs = encoder(variables, attrs)
   1123 
-> 1124     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1125 
   1126 

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    260             writer = ArrayWriter()
    261 
--> 262         variables, attributes = self.encode(variables, attributes)
    263 
    264         self.set_attributes(attributes)

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    349         # All NetCDF files get CF encoded by default, without this attempting
    350         # to write times, for example, would fail.
--> 351         variables, attributes = cf_encoder(variables, attributes)
    352         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    353         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    853     _update_bounds_encoding(variables)
    854 
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856 
    857     # Remove attrs from bounds variables (issue #2921)

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/conventions.py in <dictcomp>(.0)
    853     _update_bounds_encoding(variables)
    854 
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856 
    857     # Remove attrs from bounds variables (issue #2921)

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
    273     var = maybe_default_fill_value(var)
    274     var = maybe_encode_bools(var)
--> 275     var = ensure_dtype_not_object(var, name=name)
    276 
    277     for attr_name in CF_RELATED_DATA:

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/conventions.py in ensure_dtype_not_object(var, name)
    231             data[missing] = fill_value
    232         else:
--> 233             data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
    234 
    235         assert data.dtype.kind != "O" or data.dtype.metadata

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/xarray/conventions.py in _infer_dtype(array, name)
    165         return dtype
    166 
--> 167     raise ValueError(
    168         "unable to infer dtype on variable {!r}; xarray "
    169         "cannot serialize arbitrary Python objects".format(name)

ValueError: unable to infer dtype on variable 'obs'; xarray cannot serialize arbitrary Python objects
avivajpeyi commented 2 years ago
tic_entry.save_data(inference_data=inference_data)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/var/folders/qt/rxjvm_j566v9qn7g754s1v9hzb3p7f/T/ipykernel_10253/3890135057.py in <module>
----> 1 tic_entry.save_data(inference_data=inference_data)
      2 summary(inference_data)

~/Documents/projects/tess/tess-atlas/src/tess_atlas/data/tic_entry.py in save_data(self, inference_data)
    138         if inference_data is not None:
    139             self.inference_data = inference_data
--> 140             save_inference_data(inference_data, self.outdir)
    141         logger.info(f"Saved data in {self.outdir}")

~/Documents/projects/tess/tess-atlas/src/tess_atlas/data/inference_data_tools.py in save_inference_data(inference_data, outdir)
     32 def save_inference_data(inference_data, outdir: str):
     33     fname = get_idata_fname(outdir)
---> 34     az.to_netcdf(inference_data, filename=fname)
     35     save_samples(inference_data, outdir)
     36 

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/arviz/data/io_netcdf.py in to_netcdf(data, filename, group, coords, dims)
     50     """
     51     inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
---> 52     file_name = inference_data.to_netcdf(filename)
     53     return file_name

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    390                 if compress:
    391                     kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    393                 data.close()
    394                 mode = "a"

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1898         from ..backends.api import to_netcdf
   1899 
-> 1900         return to_netcdf(
   1901             self,
   1902             path,

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1075         # TODO: allow this work (setting up the file for writing array data)
   1076         # to be parallelized with dask
-> 1077         dump_to_store(
   1078             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1079         )

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1122         variables, attrs = encoder(variables, attrs)
   1123 
-> 1124     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1125 
   1126 

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    260             writer = ArrayWriter()
    261 
--> 262         variables, attributes = self.encode(variables, attributes)
    263 
    264         self.set_attributes(attributes)

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    349         # All NetCDF files get CF encoded by default, without this attempting
    350         # to write times, for example, would fail.
--> 351         variables, attributes = cf_encoder(variables, attributes)
    352         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    353         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    851 
    852     # add encoding for time bounds variables if present.
--> 853     _update_bounds_encoding(variables)
    854 
    855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/conventions.py in _update_bounds_encoding(variables)
    431         is_datetime_type = np.issubdtype(
    432             v.dtype, np.datetime64
--> 433         ) or contains_cftime_datetimes(v)
    434 
    435         if (

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/common.py in contains_cftime_datetimes(var)
   1839 def contains_cftime_datetimes(var) -> bool:
   1840     """Check if an xarray.Variable contains cftime.datetime objects"""
-> 1841     return _contains_cftime_datetimes(var.data)
   1842 
   1843 

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/variable.py in data(self)
    342             return self._data
    343         else:
--> 344             return self.values
    345 
    346     @data.setter

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/variable.py in values(self)
    515     def values(self):
    516         """The variable's data as a numpy.ndarray"""
--> 517         return _as_array_or_item(self._data)
    518 
    519     @values.setter

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/variable.py in _as_array_or_item(data)
    257     TODO: remove this (replace with np.asarray) once these issues are fixed
    258     """
--> 259     data = np.asarray(data)
    260     if data.ndim == 0:
    261         if data.dtype.kind == "M":

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/indexing.py in __array__(self, dtype)
    546 
    547     def __array__(self, dtype=None):
--> 548         self._ensure_cached()
    549         return np.asarray(self.array, dtype=dtype)
    550 

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/indexing.py in _ensure_cached(self)
    543     def _ensure_cached(self):
    544         if not isinstance(self.array, NumpyIndexingAdapter):
--> 545             self.array = NumpyIndexingAdapter(np.asarray(self.array))
    546 
    547     def __array__(self, dtype=None):

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/indexing.py in __array__(self, dtype)
    516 
    517     def __array__(self, dtype=None):
--> 518         return np.asarray(self.array, dtype=dtype)
    519 
    520     def __getitem__(self, key):

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/indexing.py in __array__(self, dtype)
    417     def __array__(self, dtype=None):
    418         array = as_indexable(self.array)
--> 419         return np.asarray(array[self.key], dtype=None)
    420 
    421     def transpose(self, order):

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/backends/netCDF4_.py in __getitem__(self, key)
     89 
     90     def __getitem__(self, key):
---> 91         return indexing.explicit_indexing_adapter(
     92             key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
     93         )

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/core/indexing.py in explicit_indexing_adapter(key, shape, indexing_support, raw_indexing_method)
    708     """
    709     raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support)
--> 710     result = raw_indexing_method(raw_key.tuple)
    711     if numpy_indices.tuple:
    712         # index the loaded np.ndarray

~/Documents/projects/tess/tess_venv/lib/python3.9/site-packages/xarray/backends/netCDF4_.py in _getitem(self, key)
    102             with self.datastore.lock:
    103                 original_array = self.get_array(needs_lock=False)
--> 104                 array = getitem(original_array, key)
    105         except IndexError:
    106             # Catch IndexError in netCDF4 and return a more informative

src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4.Variable.__getitem__()

src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4.Variable._get()

src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4._ensure_nc_success()

RuntimeError: NetCDF: HDF error
avivajpeyi commented 2 years ago

write gp model using celerite2 testing issue

avivajpeyi commented 2 years ago

Using a simple GP example from celerite2, I was able to save the inference trace without an issue.

Simple GP test ```python import pymc3 as pm import celerite2.theano from celerite2.theano import terms as theano_terms import celerite2 from celerite2 import terms import numpy as np import matplotlib.pyplot as plt from collections import namedtuple from corner.arviz_corner import ( _var_names, convert_to_dataset, get_coords, xarray_var_iter, ) import arviz as az from typing import List np.random.seed(42) Data = namedtuple('Data', ['t', 'y', 'yerr', 'true_t', 'true_y']) PRIOR_SIGMA = 2.0 FREQ = np.linspace(1.0 / 8, 1.0 / 0.3, 500) OMEGA = 2 * np.pi * FREQ def convert_to_numpy_list( inference_data: az.InferenceData, params: List[str] ) -> np.ndarray: dataset = convert_to_dataset(inference_data, group="posterior") var_names = _var_names(params, dataset) plotters = list( xarray_var_iter( get_coords(dataset, {}), var_names=var_names, combined=True ) ) return np.stack([x[-1].flatten() for x in plotters], axis=0) def generate_data() -> Data: t = np.sort( np.append( np.random.uniform(0, 3.8, 57), # gap between 3.8-5.5. 
np.random.uniform(5.5, 10, 68), ) ) # The input coordinates must be sorted yerr = np.random.uniform(0.08, 0.22, len(t)) y = ( 0.2 * (t - 5) + np.sin(3 * t + 0.1 * (t - 5) ** 2) + yerr * np.random.randn(len(t)) ) true_t = np.linspace(0, 10, 500) true_y = 0.2 * (true_t - 5) + np.sin(3 * true_t + 0.1 * (true_t - 5) ** 2) return Data(t, y, yerr, true_t, true_y) def plot_data(d: Data): plt.plot(d.true_t, d.true_y, "k", lw=1.5, alpha=0.3) plt.errorbar(d.t, d.y, yerr=d.yerr, fmt=".k", capsize=0) plt.xlabel("x [day]") plt.ylabel("y [ppm]") plt.xlim(0, 10) plt.ylim(-2.5, 2.5) plt.title("simulated data") plt.savefig("simulated_data.png") def build_initial_gp(): term1 = terms.SHOTerm(sigma=1.0, rho=1.0, tau=10.0) term2 = terms.SHOTerm(sigma=1.0, rho=5.0, Q=0.25) kernel = term1 + term2 gp = celerite2.GaussianProcess(kernel, mean=0.0) return gp def set_gp_params(params, gp, data): gp.mean = params[0] theta = np.exp(params[1:]) gp.kernel = terms.SHOTerm( sigma=theta[0], rho=theta[1], tau=theta[2] ) + terms.SHOTerm(sigma=theta[3], rho=theta[4], Q=0.25) gp.compute(data.t, diag=data.yerr ** 2 + theta[5], quiet=True) return gp def plot_posterior_pds(psds, freq, ax): q = np.percentile(psds, [16, 50, 84], axis=0) ax.loglog(freq, q[1], color="C0") ax.fill_between(freq, q[0], q[2], color="C0", alpha=0.1) ax.set_xlim(freq.min(), freq.max()) ax.set_xlabel("frequency [1 / day]") ax.set_ylabel("power [day ppt$^2$]") ax.set_title("posterior psd") def plot_posterior_prediction(data, gp, ax, samples=[]): ax.plot(data.true_t, data.true_y, "k", lw=1.5, alpha=0.3, label="data") ax.errorbar(data.t, data.y, yerr=data.yerr, fmt=".k", capsize=0, label="truth") gps = [] for sample in samples: gp = set_gp_params(sample, gp, data) conditional = gp.condition(data.y, data.true_t) gps.append(conditional.sample()) ax.plot(data.true_t, conditional.sample(), color="C0", alpha=0.1) q = np.percentile(gps, [16, 50, 84], axis=0) ax.fill_between(data.true_t, q[0], q[2], color="C0", alpha=0.1) ax.plot(data.true_t, 
q[1], label="Prediction") ax.set_xlabel("x [day]") ax.set_ylabel("y [ppm]") ax.set_xlim(0, 10) ax.set_ylim(-2.5, 2.5) ax.set_title("posterior prediction") ax.legend() def plot_posterior(trace, freq, data, gp): fig, axes = plt.subplots(1, 2, figsize=(10, 4)) psd = convert_to_numpy_list(trace.posterior, ["psd"]).T samples = convert_to_numpy_list(trace.posterior, ["mean", "log_sigma1", "log_rho1", "log_tau", "log_sigma2", "log_rho2", "log_jitter"]) plot_posterior_pds(psd, freq, axes[0]) plot_posterior_prediction(data, gp, axes[1], samples) plt.savefig("posterior_result.") def run_inference(data, prior_sigma, omega): with pm.Model() as model: mean = pm.Normal("mean", mu=0.0, sigma=prior_sigma) log_jitter = pm.Normal("log_jitter", mu=0.0, sigma=prior_sigma) log_sigma1 = pm.Normal("log_sigma1", mu=0.0, sigma=prior_sigma) log_rho1 = pm.Normal("log_rho1", mu=0.0, sigma=prior_sigma) log_tau = pm.Normal("log_tau", mu=0.0, sigma=prior_sigma) term1 = theano_terms.SHOTerm( sigma=pm.math.exp(log_sigma1), rho=pm.math.exp(log_rho1), tau=pm.math.exp(log_tau), ) log_sigma2 = pm.Normal("log_sigma2", mu=0.0, sigma=prior_sigma) log_rho2 = pm.Normal("log_rho2", mu=0.0, sigma=prior_sigma) term2 = theano_terms.SHOTerm( sigma=pm.math.exp(log_sigma2), rho=pm.math.exp(log_rho2), Q=0.25 ) kernel = term1 + term2 gp = celerite2.theano.GaussianProcess(kernel, mean=mean) gp.compute(data.t, diag=data.yerr ** 2 + pm.math.exp(log_jitter), quiet=True) gp.marginal("obs", observed=data.y) pm.Deterministic("psd", kernel.get_psd(omega)) trace = pm.sample( tune=1000, draws=1000, target_accept=0.9, init="adapt_full", cores=2, chains=2, random_seed=34923, return_inferencedata=True ) return trace def main(): data = generate_data() gp = build_initial_gp() trace = run_inference(data, PRIOR_SIGMA, OMEGA) plot_posterior(trace, FREQ, data, gp) az.to_netcdf(trace, filename="inference.netcdf") main() ```
avivajpeyi commented 2 years ago

I tried running again with our transit model, and got the xarray cannot serialize arbitrary Python objects error...

I guess one of the variables in my pymc3 inference object is an object -- need to figure out where this is happening.

avivajpeyi commented 2 years ago

The inference_data.observed_data.obs is what is causing the error!!

It's storing an object rather than an array...

>>> inference_data.observed_data.obs
<xarray.DataArray 'obs' (obs_dim_0: 1)>
array([Elemwise{sub,no_inplace}.0], dtype=object)
Coordinates:
  * obs_dim_0  (obs_dim_0) int64 0

Previously:

>>> inference_data.observed_data.obs
<xarray.DataArray 'obs' (obs_dim_0: 37177)>
array([-4.861564, -1.387775,  1.397878, ...,  5.90694 , 12.050734,  9.48945 ])
Coordinates:
  * obs_dim_0  (obs_dim_0) int64 0 1 2 3 4 5 ... 37172 37173 37174 37175 37176

Which commit changed this?

Goes to show that we really need better unit tests...

avivajpeyi commented 2 years ago

Aha the 'fix gp' merge is what caused this...

https://github.com/dfm/tess-atlas/commit/bfcf880c24bd0c8b1dfc71d1aca46b1f2508d235

However, we still need to use the changes made here...

dfm commented 2 years ago

I'm a little surprised about this - are you sure? I would have expected the previous version to fail with this error, not the version after that commit?!?

avivajpeyi commented 2 years ago

Hmm yeah I've probably done something silly, but I think the issue lies with changing the following line from

gp.marginal(name="obs", observed=y) ## calling this model a

to

gp.marginal(name="obs", observed=residual)  ## calling this model b

I tried sampling with the two different models and model A's obs array contains floats, while model B's obs array contains an object

Screen Shot 2022-03-04 at 12 54 41 pm Screen Shot 2022-03-04 at 12 54 50 pm

Attaching the notebook with the two models for future reference. model_comparisons.ipynb.zip

I'm adding some print statements in the model to figure out why obs is storing an object rather than float values.

avivajpeyi commented 2 years ago

Maybe this could be because

avivajpeyi commented 2 years ago

Ok, so residual being a TensorVar is the issue. Was able to reproduce this error with a simple line model:

import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
# Reproducibility: fix the RNG so both models see identical synthetic data.
np.random.seed(42)

# Ground-truth line parameters used to generate the data set.
true_m = 0.5
true_b = -1.3
true_logs = np.log(0.3)

# 50 sorted x-positions in [0, 5]; y is the true line plus Gaussian noise.
# NOTE: the uniform draw must precede the randn draw to keep the RNG stream
# identical to the original snippet.
x = np.sort(np.random.uniform(0, 5, 50))
y = true_b + true_m * x + np.exp(true_logs) * np.random.randn(len(x))

# Shared sampler settings so the two models are directly comparable.
N = 1000
sampler_kwargs = {
    "draws": N,
    "tune": N,
    "chains": 1,
    "cores": 1,
    "return_inferencedata": True,
    "progressbar": False,
}

def sample_line_model_a():
    """Model A: pass the raw observed array ``y`` to the likelihood.

    Because ``observed`` is a plain numpy array, the resulting
    ``inference_data.observed_data.obs`` holds floats and serializes
    to netCDF without issue.
    """
    with pm.Model() as model:
        slope = pm.Uniform("m", lower=-5, upper=5)
        intercept = pm.Uniform("b", lower=-5, upper=5)
        log_scatter = pm.Uniform("logs", lower=-5, upper=5)
        prediction = slope * x + intercept
        pm.Normal("obs", mu=prediction, sd=pm.math.exp(log_scatter), observed=y)
        return pm.sample(**sampler_kwargs)

def sample_line_model_b():
    """Model B: pass a tensor-valued residual as ``observed``.

    Here ``observed`` is ``y - prediction`` — a symbolic tensor, not a
    numpy array — so ``inference_data.observed_data.obs`` ends up holding
    an object, which is what breaks netCDF serialization (this issue).
    """
    with pm.Model() as model:
        slope = pm.Uniform("m", lower=-5, upper=5)
        intercept = pm.Uniform("b", lower=-5, upper=5)
        log_scatter = pm.Uniform("logs", lower=-5, upper=5)
        prediction = slope * x + intercept
        resid = y - prediction
        pm.Normal("obs", mu=0, sd=pm.math.exp(log_scatter), observed=resid)
        return pm.sample(**sampler_kwargs)

Again, here

avivajpeyi commented 2 years ago

Maybe the way to proceed is to just not save the obs array? We don't really use it/need it

dfm commented 2 years ago

Very interesting - good sleuthing!!