GalaxyL777 closed this issue 1 year ago.
Hi, thanks for spotting this problem. The issue was introduced when we merged a pull request from a branch that includes some GPU optimizations. I have pushed a commit that fixes it. The problem was mostly due to this line, where we were flattening the matrix of samples before passing it as input to log_rate_PE. I have tried running the tutorial notebooks and everything should be OK; please let me know if you experience any other issues.
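To illustrate why flattening breaks things, here is a minimal sketch (plain NumPy, not ICAROGW code). The function `requires_2d` is hypothetical and only mirrors the shape check quoted in the issue (`len(kwargs['mass_1'].shape) != 2`); the sample values are random placeholders. Flattening an N_ev x N_samples matrix removes the event axis, so the check rejects it, while the unflattened matrix (or a flattened array reshaped back) passes.

```python
import numpy as np


def requires_2d(mass_1):
    """Hypothetical stand-in for the N_ev x N_samples check in log_rate_PE."""
    if len(mass_1.shape) != 2:
        raise ValueError('The EM counterpart rate wants N_ev x N_samples arrays')
    return mass_1.shape


n_ev, n_samples = 3, 1000
rng = np.random.default_rng(0)
# Placeholder posterior samples arranged as an N_ev x N_samples matrix
mass_1 = rng.uniform(1.0, 2.0, size=(n_ev, n_samples))

# Flattening collapses the event axis to shape (3000,), so the check fails
flat = mass_1.ravel()
try:
    requires_2d(flat)
except ValueError as err:
    print('flattened input rejected:', err)

# Passing the matrix unchanged keeps shape (3, 1000), so the check passes
print(requires_2d(mass_1))
# A flattened array can be restored to the per-event layout with reshape
print(requires_2d(flat.reshape(n_ev, n_samples)))
```

Keeping the per-event axis intact matters because each row must stay tied to a single event when the likelihood is evaluated event by event.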
I discovered an error while using your program, ICAROGW. I suspect it was introduced by a recent update: the previous version ran without issues when using the GW170817 gravitational-wave event alone to infer cosmological parameters, but the latest version raises this error. (Specifically, I ran the example provided in the documentation, which constrains the Hubble constant using the GW170817 event.)
After reviewing the code, I believe the error originates from the CBC_vanilla_EM_counterpart class in wrappers.py. This class is designed to support multiple bright sirens but fails when given a single event. For a single event, the GW170817 posterior samples passed in kwargs['mass_1'] form a 1-D array [m1_1, m1_2, m1_3, ..., m1_N], so len(kwargs['mass_1'].shape) == 1. With multiple events, kwargs['mass_1'] is a 2-D array with one row of samples per event (N_ev x N_samples), so len(kwargs['mass_1'].shape) == 2. The error occurs because the code contains a check (lines 316 to 317 of wrappers.py) that requires len(kwargs['mass_1'].shape) to be 2, which is not satisfied in the single-event case:
```python
if len(kwargs['mass_1'].shape) != 2:
    raise ValueError('The EM counterpart rate wants N_ev x N_samples arrays')
```

This error can be avoided by adjusting the format of the input data, or simply by replacing the log_rate_PE method of the CBC_vanilla_EM_counterpart class with the following code:
```python
def log_rate_PE(self, prior, **kwargs):
    xp = get_module_array(prior)
    sx = get_module_array_scipy(prior)
```
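For the other workaround, adjusting the data format, a minimal sketch is to promote a single event's 1-D sample array to a 1 x N_samples matrix before it reaches log_rate_PE, so that the two-dimensional shape check passes. The helper `as_event_matrix` is hypothetical (not part of ICAROGW), and the sample values are placeholders. This also assumes the rest of log_rate_PE handles an N_ev = 1 matrix correctly, which is not confirmed here.

```python
import numpy as np


def as_event_matrix(samples):
    """Hypothetical helper: return posterior samples as an N_ev x N_samples array."""
    samples = np.asarray(samples)
    if samples.ndim == 1:
        # Single event (e.g. GW170817): add a leading event axis -> (1, N_samples)
        return samples[np.newaxis, :]
    if samples.ndim == 2:
        # Multiple events: already N_ev x N_samples, return unchanged
        return samples
    raise ValueError('expected a 1-D or 2-D array of samples')


# Single event: 1-D array of placeholder samples
single = np.linspace(1.3, 1.6, 500)
print(as_event_matrix(single).shape)

# Multiple events: shape is left as-is
multi = np.ones((4, 500))
print(as_event_matrix(multi).shape)

# Apply the same promotion to every posterior-sample entry
kwargs = {'mass_1': single, 'mass_2': single * 0.9}
kwargs = {key: as_event_matrix(val) for key, val in kwargs.items()}
print({key: val.shape for key, val in kwargs.items()})
```

For 1-D input, `np.atleast_2d` gives the same (1, N_samples) result; the explicit helper is used here only so that arrays with more than two dimensions are rejected instead of passed through.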