microsoft / causica

MIT License

CATE: conditioning on categorical variable #27

Closed. DailiZhang2010 closed this issue 1 year ago.

DailiZhang2010 commented 1 year ago

Thanks for the awesome library. However, when I try to get the CATE conditioned on a categorical variable, it keeps throwing errors. Could someone provide an example of the inputs to `model.cate()` when conditioning on a categorical variable? The variable `gender_enc` can take 3 values. I tried `conditioning_values.reshape((1, 3))`, `.reshape((3, 1))`, and `.reshape(-1)`; each attempt raised either Error 1 or Error 2 (shown at the end).
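The shape question hinges on how a categorical value is represented after processing. Assuming the data processor one-hot encodes `gender_enc` into 3 columns (an assumption inferred from the `(1, 3)` reshape attempt, not confirmed by causica's documentation), a single conditioning value becomes one row of width 3. The sketch below uses plain NumPy and a hypothetical `one_hot_row` helper, not a causica API, to show that row and the shape each attempted reshape produces.

```python
import numpy as np

# Assumption: gender_enc has 3 categories coded 0, 1, 2 and is one-hot encoded.
# causica's data processor may use a different internal layout.
N_CATEGORIES = 3


def one_hot_row(value: int, n_categories: int = N_CATEGORIES) -> np.ndarray:
    """Hypothetical helper: return a (1, n_categories) float array with a 1 at `value`."""
    row = np.zeros((1, n_categories), dtype=float)
    row[0, value] = 1.0
    return row


encoded = one_hot_row(0)
print(encoded)        # [[1. 0. 0.]]
print(encoded.shape)  # (1, 3)

# The three reshapes tried above give three different shapes.
for shape in [(1, 3), (3, 1), (-1,)]:
    print(shape, "->", encoded.reshape(shape).shape)
```

Only the first form keeps one row per conditioning sample with one column per category; `(3, 1)` turns the categories into three separate rows, and `(-1,)` drops the row dimension entirely.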

```python
import numpy as np

# Assumes `dataset`, `model` (a trained causica model) and `nodes_sel_val`
# (treatment name -> candidate treatment values) are defined earlier.
causica_estimated_ate = {}

# Conditioning on the categorical variable gender_enc (3 possible values)
conditioning_idxs = np.array([dataset.variables.name_to_idx["gender_enc"]])
conditioning_values = model.data_processor.process_data_subset_by_group(
    np.array([[0]]), conditioning_idxs
)
print(conditioning_values.shape)
# Attempted alternatives:
# conditioning_values = conditioning_values.reshape((1, 3))
# conditioning_values = np.array([1])
print(conditioning_values)
print(conditioning_values.shape)

outcome_col = "retention"
effect_idxs = np.array([dataset.variables.name_to_idx[outcome_col]])

model.cate_rff_n_features = 100

for treatment in ["seasonaldecor_purchase"]:
    intervention_idxs = np.array([dataset.variables.name_to_idx[treatment]])
    intervention_values = model.data_processor.process_data_subset_by_group(
        np.array([nodes_sel_val[treatment][1]]), intervention_idxs
    )
    # Note: the reference uses the same value as the intervention, so the
    # estimated effect should be close to zero; a different level was
    # probably intended here.
    reference_values = model.data_processor.process_data_subset_by_group(
        np.array([nodes_sel_val[treatment][1]]), intervention_idxs
    )

    cate = model.cate(
        intervention_idxs=intervention_idxs,
        intervention_values=intervention_values,
        reference_values=reference_values,
        effect_idxs=effect_idxs,
        conditioning_idxs=conditioning_idxs,
        conditioning_values=conditioning_values,
        Nsamples_per_graph=100,
        Ngraphs=1,
        most_likely_graph=True,
    )
    causica_estimated_ate[treatment] = cate[0][0]
    print(f"{treatment}: {causica_estimated_ate[treatment]}")
```
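For intuition about what the call estimates: the CATE is the difference between the expected outcome under the intervention and under the reference, restricted to one value of the conditioning variable. The sketch below is a plain-NumPy illustration on made-up samples using simple group means; it does not reproduce causica's own estimator (which appears to use random Fourier features, given the `cate_rff_n_features` setting), and `conditional_mean` / `cate_by_group` are hypothetical names. It also shows that when the reference equals the intervention, as in the loop above, the estimate collapses to zero.

```python
import numpy as np


def conditional_mean(y: np.ndarray, groups: np.ndarray, group_value: int) -> float:
    """Mean of y over the samples whose conditioning value equals group_value."""
    mask = groups == group_value
    if not mask.any():
        raise ValueError(f"no samples with conditioning value {group_value}")
    return float(y[mask].mean())


def cate_by_group(y_int, g_int, y_ref, g_ref, group_value: int) -> float:
    """Hypothetical helper: E[Y | do(T=intervention), X=x] - E[Y | do(T=reference), X=x],
    estimated by group means over samples drawn under each intervention."""
    return conditional_mean(y_int, g_int, group_value) - conditional_mean(
        y_ref, g_ref, group_value
    )


# Made-up samples: a 3-level conditioning variable and an outcome where the
# intervention adds 0.5 only for group 1 (purely illustrative numbers).
rng = np.random.default_rng(0)
groups = rng.integers(0, 3, size=1000)
y_ref = rng.normal(size=1000)
y_int = y_ref + 0.5 * (groups == 1)

for g in range(3):
    print(g, cate_by_group(y_int, groups, y_ref, groups, g))

# Same samples for intervention and reference: the estimate is zero.
print(cate_by_group(y_ref, groups, y_ref, groups, 1))
```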

Error 1: (screenshot not preserved)

Error 2: (screenshot not preserved)

Thanks a lot. Regards, Daili

agrinh commented 1 year ago

@DailiZhang2010 thanks for submitting the issue. This should be fixed as of #29. Please let us know if you run into any further issues.