Bayer-Group / pybalance

A library for minimizing the effects of confounding covariates
BSD 3-Clause "New" or "Revised" License

Balance calculators should support weighted balance calculations #27

Open sprivite opened 1 week ago

sprivite commented 1 week ago

We need to start supporting weighting models, and the first step toward that is to include weighting in the BalanceCalculator. I'm struggling to find a really compelling interface here.

As a recap: when we started, we built pybalance assuming a fixed target population, and we were matching pool --> target.

Back then we had:

m = generate_toy_dataset()
target, pool = split_target_pool(m)
gamma = GammaBalance(m)
gamma.distance(pool)

Then we dropped the assumption of a fixed target, and added an extra argument to the distance method, so now it's like this:

gamma.distance(pool, target)

Unfortunately, the order of pool and target here is the reverse of everywhere else in the code, where target comes first. This will have to be changed in the 1.0 release. It is what it is. Sorry :(

I also find that the input datatypes get very confusing. We can input DataFrames, but we can also input index lists, e.g.,

pool_ix=[0,0,1,0,1,0,1,...,0,0] # len(pool)
target_ix=[0,0,1,0,1,0,1,...,0,0] # len(target)
gamma.distance(pool_ix, target_ix)
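To make the index-list form concrete, here is a minimal NumPy sketch of what a 0/1 index mask means, assuming (as the `# len(pool)` comments suggest) that each entry flags whether the corresponding row of `pool` or `target` is included. The helper `subset_by_mask` is hypothetical and not part of pybalance.

```python
import numpy as np

def subset_by_mask(features: np.ndarray, mask) -> np.ndarray:
    """Hypothetical helper: return the rows of `features` flagged with 1 in `mask`."""
    mask = np.asarray(mask, dtype=bool)
    if mask.shape != (features.shape[0],):
        raise ValueError("mask must have exactly one entry per row of features")
    return features[mask]

# Toy feature matrix: 5 patients, 2 covariates.
pool = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 0.0], [4.0, 1.0], [5.0, 0.0]])
pool_ix = [0, 1, 0, 1, 1]  # len(pool)
selected = subset_by_mask(pool, pool_ix)
print(selected.shape)  # (3, 2)
```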

Do we allow mixed datatypes?

gamma.distance(pool, target_ix)
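One way mixed inputs could be dispatched is sketched below, under a hypothetical rule: a 1-D input is a 0/1 mask over stored reference rows, and a 2-D input (or anything with `.to_numpy()`, such as a pandas DataFrame) is the feature table itself. `as_features` and the toy `mean_difference` metric are illustrations only; the metric is not GammaBalance.

```python
import numpy as np

def as_features(x, reference: np.ndarray) -> np.ndarray:
    """Hypothetical rule: 1-D input = 0/1 mask over `reference`, 2-D input = feature table."""
    if hasattr(x, "to_numpy"):  # e.g. a pandas DataFrame
        x = x.to_numpy()
    x = np.asarray(x)
    if x.ndim == 1:
        if x.shape[0] != reference.shape[0]:
            raise ValueError("mask length must equal the number of reference rows")
        return reference[x.astype(bool)]
    if x.ndim == 2:
        return x
    raise ValueError("expected a 1-D mask or a 2-D feature table")

def mean_difference(pool, target, pool_ref, target_ref) -> float:
    """Toy balance metric: mean absolute difference of covariate means."""
    p = as_features(pool, pool_ref)
    t = as_features(target, target_ref)
    return float(np.mean(np.abs(p.mean(axis=0) - t.mean(axis=0))))

pool_ref = np.array([[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]])
target_ref = np.array([[1.0, 1.0], [3.0, 3.0]])
d1 = mean_difference(pool_ref, [0, 1], pool_ref, target_ref)      # table + mask
d2 = mean_difference([1, 1, 0], target_ref, pool_ref, target_ref)  # mask + table
```

The weak spot of a rank-based rule like this is that a stacked batch of masks is also 2-D, so it becomes indistinguishable from a feature table by shape alone.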

It's not so hard so far, but remember we also allow multidimensional indices! So ...

gamma.distance([pool_ix1, pool_ix2], [target_ix1, target_ix2])

which is operationally (but not necessarily computationally) equivalent to:

[gamma.distance(pool_ix1, target_ix1), gamma.distance(pool_ix2, target_ix2)]

The point of this form is that we can compute balance on a GPU and thereby evaluate, say, 1000 candidate populations at once if we want. So all this complexity comes from wanting to support the EA (evolutionary algorithm) solver.
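The sketch below shows why the stacked form pays off, using a toy mean-difference metric as a stand-in for GammaBalance (it does not implement GammaBalance). With masks stacked into a `(k, n)` array, the covariate means of all `k` subpopulations come out of a single matrix product, the kind of operation a GPU batches well. It checks that the batched result matches the per-pair loop. Function names are hypothetical and NumPy stands in for a GPU array library.

```python
import numpy as np

def batched_mean_difference(pool_masks, target_masks, pool_X, target_X):
    """Toy metric for k (pool, target) pairs at once; masks are (k, n) 0/1 arrays."""
    pm = np.asarray(pool_masks, dtype=float)
    tm = np.asarray(target_masks, dtype=float)
    # (k, n) @ (n, d) -> (k, d) covariate sums, divided by subpopulation sizes.
    pool_means = (pm @ pool_X) / pm.sum(axis=1, keepdims=True)
    target_means = (tm @ target_X) / tm.sum(axis=1, keepdims=True)
    return np.abs(pool_means - target_means).mean(axis=1)

def single_mean_difference(pool_mask, target_mask, pool_X, target_X):
    """Same toy metric for one pair, for comparison with the batched version."""
    p = pool_X[np.asarray(pool_mask, dtype=bool)]
    t = target_X[np.asarray(target_mask, dtype=bool)]
    return np.abs(p.mean(axis=0) - t.mean(axis=0)).mean()

rng = np.random.default_rng(0)
pool_X = rng.normal(size=(50, 3))
target_X = rng.normal(size=(20, 3))
pool_masks = rng.integers(0, 2, size=(1000, 50))
target_masks = rng.integers(0, 2, size=(1000, 20))
# Guarantee no empty subpopulation (avoids division by zero).
pool_masks[:, 0] = 1
target_masks[:, 0] = 1

batched = batched_mean_difference(pool_masks, target_masks, pool_X, target_X)
looped = np.array([
    single_mean_difference(pool_masks[i], target_masks[i], pool_X, target_X)
    for i in range(1000)
])
```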

Even so, I think the current implementation handles this pretty well.

Now we want to add weights to the mix. Naively, I would say we add two more arguments: pool_weights and target_weights.

gamma.distance(pool, target, pool_weights, target_weights)

If pool and target are DataFrames or 1D indices, then pool_weights and target_weights must be 1D arrays. If pool / target are multidimensional index arrays, then pool_weights and target_weights must also be multidimensional and have the same first dimension.
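The shape rules above can be sketched for the index-array case as follows, again with a toy weighted mean-difference metric standing in for GammaBalance. Several details are assumptions: `weighted_distance` and `_check_weights` are hypothetical names, omitted weights default to uniform, and the sketch checks the stricter condition that weights match the index array's full shape rather than only its first dimension.

```python
import numpy as np

def _check_weights(ix, w, name):
    """Hypothetical validation: weights default to ones and must match ix's shape."""
    ix = np.asarray(ix, dtype=float)
    if ix.ndim not in (1, 2):
        raise ValueError(f"{name} index must be 1-D or 2-D")
    if w is None:
        return ix, np.ones_like(ix)
    w = np.asarray(w, dtype=float)
    if w.shape != ix.shape:
        raise ValueError(f"{name}_weights must have shape {ix.shape}, got {w.shape}")
    return ix, w

def weighted_distance(pool_ix, target_ix, pool_X, target_X,
                      pool_weights=None, target_weights=None):
    """Toy weighted metric: mean |difference| of weighted covariate means."""
    pm, pw = _check_weights(pool_ix, pool_weights, "pool")
    tm, tw = _check_weights(target_ix, target_weights, "target")
    if pm.ndim != tm.ndim:
        raise ValueError("pool and target index arrays must have the same ndim")
    single = pm.ndim == 1
    pm, pw, tm, tw = (np.atleast_2d(a) for a in (pm, pw, tm, tw))
    if pm.shape[0] != tm.shape[0]:
        raise ValueError("pool and target batches must have the same first dimension")
    # Effective weight of a row: its 0/1 inclusion flag times its weight.
    peff, teff = pm * pw, tm * tw
    pool_means = (peff @ pool_X) / peff.sum(axis=1, keepdims=True)
    target_means = (teff @ target_X) / teff.sum(axis=1, keepdims=True)
    out = np.abs(pool_means - target_means).mean(axis=1)
    return float(out[0]) if single else out

pool_X = np.array([[0.0], [2.0], [4.0]])
target_X = np.array([[1.0], [3.0]])
unweighted = weighted_distance([1, 1, 1], [1, 1], pool_X, target_X)
weighted = weighted_distance([1, 1, 1], [1, 1], pool_X, target_X,
                             pool_weights=[3.0, 1.0, 0.0])
batch = weighted_distance([[1, 1, 1], [1, 1, 0]], [[1, 1], [1, 1]], pool_X, target_X)
```

Here `unweighted` is 0.0, `weighted` is 1.5 (the pool mean shifts to 0.5), and `batch` evaluates two pairs in one call.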

This is the best I can come up with right now, at least for something that preserves backward compatibility. I can't put my finger on why, but I'm not so happy with this. I think for now it's the way forward, but I'd also like to think about 1.0: what would a better interface look like?