judithabk6 / med_bench

BSD 3-Clause "New" or "Revised" License

m.astype(int) for multiply_robust #41

Open sami6mz opened 1 year ago

sami6mz commented 1 year ago

In get_estimation.py, in get_estimation():

The mediator is forced to be binary by the line m.astype(int):

https://github.com/judithabk6/med_bench/blob/13c97dadd3209f1fdb6e8b845225a88d89f0e6e0/src/get_estimation.py#L479C19-L479C19

I'm not sure this is a good idea, since the user is not warned. Moreover, #35 already returns an error if m is not binary.
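To illustrate the concern, here is a minimal NumPy sketch, assuming `m` is a NumPy array (as the `astype` call suggests). The cast truncates values toward zero and emits no message, so a continuous mediator is silently altered rather than rejected. The mediator values are made up for illustration.

```python
import numpy as np

# A continuous mediator (illustrative values, not from med_bench).
m = np.array([0.2, 0.9, 1.7, -0.4])

# Same cast as in get_estimation(): truncates toward zero, no warning emitted.
m_cast = m.astype(int)
print(m_cast)  # [0 0 1 0]
```

Note that the result only happens to be 0/1 here; a value such as 2.5 would become 2, so the cast does not even guarantee a binary mediator.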

bthirion commented 1 year ago

But you indeed want to enforce this behavior. It just needs a warning, no?
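A possible "cast, but warn" variant could look like the sketch below. `cast_mediator_with_warning` is a hypothetical helper, not part of med_bench; it warns only when the cast actually changes some values, so an already-binary mediator passes silently.

```python
import warnings

import numpy as np


def cast_mediator_with_warning(m):
    """Hypothetical helper: cast the mediator to int, warning if any value changes."""
    m = np.asarray(m)
    m_int = m.astype(int)
    # Only warn when the cast actually alters the data (e.g. 0.7 -> 0).
    if not np.array_equal(m, m_int):
        warnings.warn(
            "The mediator was cast to int; non-integer values were truncated.",
            UserWarning,
            stacklevel=2,
        )
    return m_int


# Usage: an already-binary mediator passes silently, a continuous one warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cast_mediator_with_warning([0.0, 1.0, 1.0])
    cast_mediator_with_warning([0.2, 0.9])
print(len(caught))  # 1
```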

judithabk6 commented 1 year ago

Yes. We need a global refactor of the input-format checks for all estimators. The latest ones implemented by @sami6mz are good, but the older ones are not, and we probably want a general function for this, maybe in the spirit of https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/validation.py.
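A shared validator in that spirit might look roughly like this. scikit-learn's module provides helpers such as `check_array` and `check_consistent_length`; the sketch below reimplements a few similar checks with plain NumPy so it stays self-contained. `check_inputs`, its `(y, t, m, x)` argument order and the `m_type` flag are all assumptions for illustration, not med_bench's actual API.

```python
import numpy as np


def _check_binary(a, name):
    """Raise a ValueError if `a` contains values other than 0 and 1."""
    values = np.unique(a)
    if not np.isin(values, [0, 1]).all():
        raise ValueError(f"{name} must be binary (0/1); found values {values[:5]}")


def check_inputs(y, t, m, x, m_type="binary"):
    """Hypothetical shared validator, loosely modelled on sklearn.utils.validation.

    Converts inputs to NumPy arrays, reshapes 1-D m and x to column vectors,
    checks that all inputs have the same number of samples, and enforces a
    binary treatment (plus a binary mediator when m_type == "binary").
    """
    y = np.asarray(y).ravel()
    t = np.asarray(t).ravel()
    m = np.asarray(m)
    x = np.asarray(x)
    if m.ndim == 1:
        m = m.reshape(-1, 1)
    if x.ndim == 1:
        x = x.reshape(-1, 1)

    n = y.shape[0]
    for name, arr in (("t", t), ("m", m), ("x", x)):
        if arr.shape[0] != n:
            raise ValueError(f"{name} has {arr.shape[0]} samples, expected {n}")

    _check_binary(t, "t")
    if m_type == "binary":
        _check_binary(m, "m")
    return y, t, m, x
```

Raising an error (rather than casting) makes a non-binary mediator fail loudly, which matches the behavior attributed to #35.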

sami6mz commented 1 year ago

> But you indeed want to enforce this behavior. It just needs a warning, no?

Yes, but we don't want this behavior to be in get_estimation(). get_estimation() should only call the estimator, without any hidden behavior. The enforcement should be in the estimator itself and, as @judithabk6 said, probably in a general function.
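The split of responsibilities proposed here could be sketched as follows. All names (`_validate_binary_mediator`, `multiply_robust_sketch`, `get_estimation_sketch`) are hypothetical stand-ins rather than med_bench functions, and the estimator body returns a placeholder instead of a real estimate. The point is only where the check lives: the estimator validates its own inputs, while the dispatcher passes them through untouched.

```python
import numpy as np


def _validate_binary_mediator(m):
    """Hypothetical validator shared by estimators that need a binary mediator."""
    m = np.asarray(m)
    if not np.isin(np.unique(m), [0, 1]).all():
        raise ValueError("This estimator requires a binary mediator (0/1 values).")
    return m.astype(int)  # safe here: values are already 0/1


def multiply_robust_sketch(y, t, m, x):
    """Hypothetical estimator stand-in: input validation happens here."""
    m = _validate_binary_mediator(m)
    # The actual estimation would follow; return a placeholder, not a real effect.
    return float(np.mean(y))


def get_estimation_sketch(y, t, m, x, estimator):
    """Hypothetical dispatcher: forwards inputs without modifying them."""
    estimators = {"multiply_robust": multiply_robust_sketch}
    return estimators[estimator](y, t, m, x)
```

With this layout, calling the dispatcher with a continuous mediator raises the estimator's error instead of silently truncating the data.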