joshjchayes / TransitFit

Transit light curve fitting using nested sampling
GNU General Public License v3.0

fix for broadcast error #11

Closed: sourestdeeds closed this 1 year ago

sourestdeeds commented 3 years ago

This seems to fix the broadcast error with n_procs > 1. Printing the value beforehand showed that it is already a NumPy array. I'm not sure whether the np.array conversion is needed for the other results, but it works in this form.

Traceback for the associated error:

Traceback (most recent call last):
  File "/usr/local/anaconda3/lib/python3.8/site-packages/firefly-0.7.9-py3.8.egg/firefly/auto_retrieval.py", line 337, in firefly
    _retrieval(
  File "/usr/local/anaconda3/lib/python3.8/site-packages/firefly-0.7.9-py3.8.egg/firefly/_utils.py", line 392, in _retrieval
    run_retrieval(
  File "/usr/local/anaconda3/lib/python3.8/site-packages/transitfit/_pipeline.py", line 321, in run_retrieval
    results = retriever.run_retrieval(ld_fit_method, fitting_mode,
  File "/usr/local/anaconda3/lib/python3.8/site-packages/transitfit/retriever.py", line 661, in run_retrieval
    results = self._run_folded_retrieval(ld_fit_method, detrend, normalise,
  File "/usr/local/anaconda3/lib/python3.8/site-packages/transitfit/retriever.py", line 474, in _run_folded_retrieval
    results, priors, lightcurves = self._run_batched_retrieval(self.all_lightcurves,
  File "/usr/local/anaconda3/lib/python3.8/site-packages/transitfit/retriever.py", line 419, in _run_batched_retrieval
    all_lightcurves = np.array([r[2] for r in batch_run_results])
ValueError: could not broadcast input array from shape (1,6) into shape (1,1)
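The traceback shows np.array failing to stack per-batch results whose shapes differ ((1,6) vs (1,1)). A minimal sketch of one way to sidestep this, assuming the per-batch outputs only need to be collected and not combined numerically: preallocate a 1D object array and fill it slot by slot. `stack_ragged` is a hypothetical helper, not part of TransitFit, and the zero-filled arrays are dummy stand-ins.

```python
import numpy as np


def stack_ragged(items):
    """Collect per-batch outputs into a 1D object array.

    Unlike np.array(items), this never tries to broadcast the items
    against one another, so differently shaped arrays are kept intact.
    """
    out = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        out[i] = item  # each slot stores one batch's array unchanged
    return out


# Hypothetical stand-ins for the r[2] values of two batches, using the
# shapes reported in the traceback
batch_lightcurves = [np.zeros((1, 1)), np.zeros((1, 6))]
packed = stack_ragged(batch_lightcurves)
print(packed.shape)     # (2,)
print(packed[1].shape)  # (1, 6)
```

Element-wise assignment is used instead of `out[:] = items` because slice assignment can itself attempt broadcasting.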

sourestdeeds commented 3 years ago

I'm thinking the solution lies more along the lines of:

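import multiprocessing as mp
import numpy as np

# n_procs, _run_batch and mp_input come from the surrounding retriever code
with mp.Pool(processes=n_procs) as pool:
    batch_run_results = pool.map(_run_batch, mp_input)

all_results = np.array([r[0] for r in batch_run_results])
all_priors = np.array([r[1] for r in batch_run_results])
# Flatten each batch's light curves to 1D before stacking (the proposed
# fix for the "(1,6) into shape (1,1)" broadcast error)
all_lightcurves = np.array([np.append(np.array([]), r[2]) for r in batch_run_results])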
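To see what the proposed `np.append(np.array([]), r[2])` call does, here is a small sketch with dummy zero-filled arrays standing in for two batches' `r[2]` values (the arrays are assumptions; in the retriever, `r[2]` holds each batch's light curves):

```python
import numpy as np

# Hypothetical stand-ins for two batches' r[2] values, with the shapes
# reported in the traceback
batch_a = np.zeros((1, 1))
batch_b = np.zeros((1, 6))

# With no axis given, np.append flattens both inputs before joining them,
# so prepending an empty array simply flattens r[2] to 1D
flat_a = np.append(np.array([]), batch_a)
flat_b = np.append(np.array([]), batch_b)
print(flat_a.shape, flat_b.shape)  # (1,) (6,)
```

The flattened arrays can still differ in length. NumPy 1.24 and later raise a ValueError when np.array is given ragged input without dtype=object, so on newer NumPy versions that line may need an explicit dtype=object or a preallocated object array.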