dfm / tess-atlas

MIT License
9 stars 8 forks source link

Theano pickling error #266

Closed avivajpeyi closed 1 year ago

avivajpeyi commented 1 year ago

Got 10 of these errors out of ~3k.

planet_transit_model, params = build_planet_transit_model(tic_entry)
model_varnames = get_untransformed_varnames(planet_transit_model)
test_model(planet_transit_model)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-1-797d677cee95> in <module>
----> 1 planet_transit_model, params = build_planet_transit_model(tic_entry)
      2 model_varnames = get_untransformed_varnames(planet_transit_model)
      3 test_model(planet_transit_model)

<ipython-input-1-44cb72db158f> in build_planet_transit_model(tic_entry)
    102                     upper=planet.tmax + planet.duration_max,
    103                 )
--> 104                 tmax_prior = tmax_norm(
    105                     name=f"{TIME_END}_{planet.index}",
    106                     mu=planet.tmax,

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/distributions/bound.py in __call__(self, name, *args, **kwargs)
    293         transform = kwargs.pop("transform", "infer")
    294         if issubclass(self.distribution, Continuous):
--> 295             return _ContinuousBounded(
    296                 name, self.distribution, self.lower, self.upper, transform, *args, **kwargs
    297             )

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/distributions/distribution.py in __new__(cls, name, *args, **kwargs)
    120         else:
    121             dist = cls.dist(*args, **kwargs)
--> 122         return model.Var(name, dist, data, total_size, dims=dims)
    123 
    124     def __getnewargs__(self):

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/model.py in Var(self, name, dist, data, total_size, dims)
   1140             else:
   1141                 with self:
-> 1142                     var = TransformedRV(
   1143                         name=name,
   1144                         distribution=dist,

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/model.py in __init__(self, type, owner, index, name, distribution, model, transform, total_size)
   2011 
   2012             self.transformed = model.Var(
-> 2013                 transformed_name, transform.apply(distribution), total_size=total_size
   2014             )
   2015 

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/distributions/transforms.py in apply(self, dist)
    124     def apply(self, dist):
    125         # avoid circular import
--> 126         return TransformedDistribution.dist(dist, self)
    127 
    128     def __str__(self):

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/distributions/distribution.py in dist(cls, *args, **kwargs)
    128     def dist(cls, *args, **kwargs):
    129         dist = object.__new__(cls)
--> 130         dist.__init__(*args, **kwargs)
    131         return dist
    132 

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/distributions/transforms.py in __init__(self, dist, transform, *args, **kwargs)
    152         self.dist = dist
    153         self.transform_used = transform
--> 154         v = forward(FreeRV(name="v", distribution=dist))
    155         self.type = v.type
    156 

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/model.py in __init__(self, type, owner, index, name, distribution, total_size, model)
   1672             # The logp might need scaling in minibatches.
   1673             # This is done in `Factor`.
-> 1674             self.logp_sum_unscaledt = distribution.logp_sum(self)
   1675             self.logp_nojac_unscaledt = distribution.logp_nojac(self)
   1676             self.total_size = total_size

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/pymc3/distributions/distribution.py in logp_sum(self, *args, **kwargs)
    265         if only the sum of the logp values is needed.
    266         """
--> 267         return tt.sum(self.logp(*args, **kwargs))
    268 
    269     __latex__ = _repr_latex_

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/tensor/basic.py in sum(input, axis, dtype, keepdims, acc_dtype)
   3219     """
   3220 
-> 3221     out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
   3222 
   3223     if keepdims:

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/graph/op.py in __call__(self, *inputs, **kwargs)
    251 
    252         if config.compute_test_value != "off":
--> 253             compute_test_value(node)
    254 
    255         if self.default_output is not None:

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/graph/op.py in compute_test_value(node)
    124 
    125     # Create a thunk that performs the computation
--> 126     thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
    127     thunk.inputs = [storage_map[v] for v in node.inputs]
    128     thunk.outputs = [storage_map[v] for v in node.outputs]

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/graph/op.py in make_thunk(self, node, storage_map, compute_map, no_recycling, impl)
    632             )
    633             try:
--> 634                 return self.make_c_thunk(node, storage_map, compute_map, no_recycling)
    635             except (NotImplementedError, MethodNotDefined):
    636                 # We requested the c code, so don't catch the error.

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/graph/op.py in make_c_thunk(self, node, storage_map, compute_map, no_recycling)
    598                 print(f"Disabling C code for {self} due to unsupported float16")
    599                 raise NotImplementedError("float16")
--> 600         outputs = cl.make_thunk(
    601             input_storage=node_input_storage, output_storage=node_output_storage
    602         )

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/link/c/basic.py in make_thunk(self, input_storage, output_storage, storage_map)
   1201         """
   1202         init_tasks, tasks = self.get_init_tasks()
-> 1203         cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
   1204             input_storage, output_storage, storage_map
   1205         )

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/link/c/basic.py in __compile__(self, input_storage, output_storage, storage_map)
   1136         input_storage = tuple(input_storage)
   1137         output_storage = tuple(output_storage)
-> 1138         thunk, module = self.cthunk_factory(
   1139             error_storage,
   1140             input_storage,

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/link/c/basic.py in cthunk_factory(self, error_storage, in_storage, out_storage, storage_map)
   1632             for node in self.node_order:
   1633                 node.op.prepare_node(node, storage_map, None, "c")
-> 1634             module = get_module_cache().module_from_key(key=key, lnk=self)
   1635 
   1636         vars = self.inputs + self.outputs + self.orphans

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/link/c/cmodule.py in module_from_key(self, key, lnk)
   1155         # Is the source code already in the cache?
   1156         module_hash = get_module_hash(src_code, key)
-> 1157         module = self._get_from_hash(module_hash, key)
   1158         if module is not None:
   1159             return module

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/link/c/cmodule.py in _get_from_hash(self, module_hash, key)
   1058             with lock_ctx():
   1059                 try:
-> 1060                     key_data.add_key(key, save_pkl=bool(key[0]))
   1061                     key_broken = False
   1062                 except pickle.PicklingError:

/fred/oz200/avajpeyi/envs/tess/lib/python3.8/site-packages/theano/link/c/cmodule.py in add_key(self, key, save_pkl)
    495 
    496         """
--> 497         assert key not in self.keys
    498         self.keys.add(key)
    499         if save_pkl:

AssertionError: 
avivajpeyi commented 1 year ago

Fixed in commit 121b4ce47d32b23764e287b38992abacd8c2f3c3