jakobrunge / tigramite

Tigramite is a python package for causal inference with a focus on time series data. The Tigramite documentation is at
https://jakobrunge.github.io/tigramite/
GNU General Public License v3.0

CausalMediation.predict_ on dataframes with NaNs #404

Open emfons opened 2 months ago

emfons commented 2 months ago

Hi Tigramite team! I like the new CausalMediation class, but I am having trouble using it on data that contain NaNs. It works fine when I drop all rows that contain any NaN, but in my case I want to keep the NaNs. With them, evaluating .predict_natural_direct_effect() fails with an IndexError raised at this line in MixedData._FixZeroDim: "categorical_data = np.empty([0, np.shape(continuous_data)[1]], dtype="int32")". This happens because continuous_data is a list of arrays that do NOT all have the same size, due to the NaNs. Is there a way to fix this? Thanks!
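Why a ragged list breaks that line can be reproduced with plain NumPy. This is a minimal sketch that assumes continuous_data is a list of 1-D per-variable arrays (one entry per sample); the helper shared_sample_count and the toy arrays are hypothetical and are not tigramite code. Depending on the NumPy version, indexing the shape of a ragged list fails with an IndexError (older versions build a 1-D object array) or a ValueError (NumPy 1.24 and later reject ragged input).

```python
import numpy as np

def shared_sample_count(continuous_data):
    """Mimic np.shape(continuous_data)[1] on a list of per-variable arrays.

    Returns the common number of samples, or None if the arrays are ragged.
    """
    try:
        return np.shape(continuous_data)[1]
    except (IndexError, ValueError):
        # Older NumPy: object array of shape (k,), so [1] raises IndexError.
        # NumPy >= 1.24: ragged input raises ValueError when converted.
        return None

aligned = [np.zeros(5), np.ones(5)]  # every variable has 5 samples
ragged = [np.zeros(5), np.ones(4)]   # one variable lost a sample to a NaN

print(shared_sample_count(aligned))  # 5
print(shared_sample_count(ragged))   # None
```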

martin-rabel commented 2 months ago

Hi,

Thanks for your interest in tigramite and its mediation functionality in particular! Sorry for the delay: I was on vacation last week, and Jakob's initial reply seems to have gone missing somewhere along the way.

I could reproduce the problem using a tigramite dataframe initialized with a "missing_flag" value; I assume that is where the issue arises in your case as well? In general, the mediation code is a rather new addition and has not yet been tested extensively with missing data or masking, but the issue should be relatively easy to fix (I hope).

When using the interface as implemented in the CausalMediation class (as you seem to be doing), missing data and masking are handled by the tigramite dataframe implementation. However, mediation effects usually require multiple fits, and I think the issue is that the data lookups for these fits are not "synchronized", in the sense that each fit may receive its own, different set of valid samples.
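The synchronization problem can be illustrated with toy data. This is a hypothetical sketch, not how tigramite performs its lookups internally: the flag value -999.0, the column layout (X, M, Y) and the helper valid_rows are assumptions. Each fit that masks only the variables it uses keeps a different number of rows, while combining the masks of all fits gives every fit the same samples.

```python
import numpy as np

MISSING = -999.0  # hypothetical missing_flag value

# Toy data with columns X, M, Y; each row is one sample.
data = np.array([
    [0.1,     1.0,     2.0],
    [MISSING, 1.1,     2.1],
    [0.3,     MISSING, 2.2],
    [0.4,     1.3,     MISSING],
    [0.5,     1.4,     2.4],
])

def valid_rows(data, columns):
    """Boolean mask of rows where none of the given columns is flagged."""
    return ~np.any(data[:, columns] == MISSING, axis=1)

# Per-lookup masks: each fit only checks the variables it uses.
mask_fit_m = valid_rows(data, [0, 1])     # fit of M on X
mask_fit_y = valid_rows(data, [0, 1, 2])  # fit of Y on X, M
print(mask_fit_m.sum(), mask_fit_y.sum())  # 3 2 (different sample sets)

# Synchronized lookup: one joint mask shared by all fits.
joint = mask_fit_m & mask_fit_y
samples_m = data[joint][:, [0, 1]]
samples_y = data[joint][:, [0, 1, 2]]
print(len(samples_m), len(samples_y))  # 2 2 (consistent sample sets)
```

The joint mask trades samples for consistency: each fit sees only rows valid for all fits.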

This should be fixed on our side, so that the general missing-value and masking capabilities already implemented in tigramite are also available to the mediation code. I don't think this should take too long, and I'll keep this thread updated.

In principle, there is another approach: not setting the "missing_flag" on the dataframe and instead handling missing values "per fit" (in the fitter). This is sometimes more sample-efficient (in theory), and since the fitting module is extensible by design (see appendix B of the mediation tutorial), it can in principle be added "from the outside". I'll look into this as well.

This should probably be documented in the tutorial as well.

Best, Martin

emfons commented 2 months ago

Hi Martin, I'm glad you were able to reproduce the issue. From what you describe, this is exactly the same issue as the one I'm experiencing. Looking forward to trying it once the fix is implemented!

martin-rabel commented 2 months ago

Hi,

The second option (handling missing values per fit) turns out to be a rather simple fix, but it does not immediately provide the other functionality of the tigramite dataframe (such as masking).

The following code uses the extension point as described in appendix B of the tutorial to filter out samples flagged as "missing":

```python
import numpy as np

class FitProvider_Continuous_Filtered:
    """Wraps a continuous fit provider; drops flagged samples before fitting."""
    def __init__(self, underlying_fit_provider, filter_to_use):
        self.filter = filter_to_use
        self.underlying_fit_provider = underlying_fit_provider

    def Get_Fit_Continuous(self, x_train, y_train):
        return self.underlying_fit_provider.Get_Fit_Continuous(*self.filter.apply(x_train, y_train))

class FitProvider_Density_Filtered:
    """Wraps a density fit provider; drops flagged samples before fitting."""
    def __init__(self, underlying_fit_provider, filter_to_use):
        self.filter = filter_to_use
        self.underlying_fit_provider = underlying_fit_provider

    def Get_Fit_Density(self, x_train):
        return self.underlying_fit_provider.Get_Fit_Density(self.filter.apply(x_train))

class FilterMissingValues:
    """Removes every sample (row) in which any input equals missing_value_flag.

    Note: `==` never matches NaN, so NaNs must be replaced by the numeric flag first.
    """
    def __init__(self, missing_value_flag):
        self.missing_value_flag = missing_value_flag

    def apply(self, x, y=None):
        # x: 2-D array (samples x variables); a row is missing if any entry is flagged
        missing_in_any_x = np.any(x == self.missing_value_flag, axis=1)
        if y is None:
            valid = np.logical_not(missing_in_any_x)
            return x[valid]
        else:
            # y: one target value per sample
            missing_in_y = (y == self.missing_value_flag)
            valid = np.logical_not(np.logical_or(missing_in_any_x, missing_in_y))
            return x[valid], y[valid]

def apply_filter_to_all_inputs(fit_setup, filter_to_apply):
    # Assume the fit_setup can be constructed from map & density fit and has corresponding members
    # (for all implementations based on the FitSetup class in the mediation-module
    #  of tigramite this is the case; see tutorial on mediation, appendix B)
    return fit_setup.__class__(
        fit_map=FitProvider_Continuous_Filtered(fit_setup.fit_map, filter_to_apply),
        fit_density=FitProvider_Density_Filtered(fit_setup.fit_density, filter_to_apply),
    )
```

Then, using this, replacing the fit configuration and the dataframe as follows

```python
import tigramite.data_processing as pp
fit_setup2 = apply_filter_to_all_inputs(fit_setup, FilterMissingValues(special_value))
# missing_flag is deliberately NOT set; the filter above handles missing values per fit
dataframe_missing = pp.DataFrame(data, var_names=var_names)
```

fixed the problem for me. [Consider the code licensed as if it shipped with tigramite.]
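To check the filtering wrapper in isolation, without tigramite, one can pass a stub in place of a real fit provider. In this sketch RecordingFitProvider, MISSING and the toy arrays are hypothetical, and the filter and wrapper repeat the logic above so the snippet runs on its own. It assumes x_train is 2-D (samples x variables) and y_train is 1-D (one value per sample), which is what the axis=1 reduction in the filter expects.

```python
import numpy as np

MISSING = -999.0  # hypothetical flag value for this check

class FilterMissingValues:
    # Same logic as the filter above, written with ~ and | instead of logical_*.
    def __init__(self, missing_value_flag):
        self.missing_value_flag = missing_value_flag

    def apply(self, x, y=None):
        missing_in_any_x = np.any(x == self.missing_value_flag, axis=1)
        if y is None:
            return x[~missing_in_any_x]
        valid = ~(missing_in_any_x | (y == self.missing_value_flag))
        return x[valid], y[valid]

class FitProvider_Continuous_Filtered:
    def __init__(self, underlying_fit_provider, filter_to_use):
        self.filter = filter_to_use
        self.underlying_fit_provider = underlying_fit_provider

    def Get_Fit_Continuous(self, x_train, y_train):
        return self.underlying_fit_provider.Get_Fit_Continuous(*self.filter.apply(x_train, y_train))

class RecordingFitProvider:
    """Hypothetical stand-in for a real provider: records what it receives."""
    def Get_Fit_Continuous(self, x_train, y_train):
        self.received = (x_train.copy(), y_train.copy())
        return None  # a real provider would return a fitted model here

x = np.array([[0.1, 1.0], [MISSING, 1.1], [0.3, 1.2], [0.4, 1.3]])
y = np.array([2.0, 2.1, MISSING, 2.3])

stub = RecordingFitProvider()
FitProvider_Continuous_Filtered(stub, FilterMissingValues(MISSING)).Get_Fit_Continuous(x, y)
x_seen, y_seen = stub.received
print(x_seen.shape, y_seen.shape)  # (2, 2) (2,)
```

Rows 1 (flag in x) and 2 (flag in y) are dropped, so the wrapped provider only ever sees complete samples.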

It would be nice to hear whether this works for your case as well. The tutorial update and the tigramite-dataframe-based fix will probably come with the next update(s); I still have to look into the details there.
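Because the filter compares with ==, and in floating point NaN never compares equal to anything (not even NaN), raw data containing NaNs, as in the original report, would pass through it unchanged. One option is to replace the NaNs with the numeric flag before building the dataframe. A minimal sketch, where special_value and the toy array are hypothetical; the flag must be a value that never occurs in the real data.

```python
import numpy as np

special_value = -999.0  # hypothetical flag; must not occur as a real data value

raw = np.array([[0.1, np.nan], [0.2, 1.2], [np.nan, 1.3]])

# NaN == NaN is False, so an `==`-based filter cannot detect NaNs directly.
print(np.any(raw == np.nan))  # False

# Replace every NaN with the numeric flag before building the dataframe.
data = np.where(np.isnan(raw), special_value, raw)
print(np.sum(data == special_value))  # 2
```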

Best, Martin