Valentyn1997 / CausalTransformer

Code for the paper "Causal Transformer for Estimating Counterfactual Outcomes"
MIT License

Question about the MSM config parameters #6

Closed: lishuang1206 closed this issue 1 year ago.

lishuang1206 commented 1 year ago

Hi, thank you so much for your code; this is really great work. May I ask how to run the MSM? When I run train_msm.py, it requires dataset.treatment_model, so I set it to "multilabel". When it calculates test_rmses for test_cf_treatment_seq, the function get_normalised_n_step_rmses requires model_type to be 'decoder', 'multi', or 'g_net', but here the value of model_type is "msm_regressor". Do I need to change this value? If so, which one should I use? Here is the corresponding code:

```python
test_rmses = {}
if hasattr(dataset_collection, 'test_cf_treatment_seq'):  # Test n_step_counterfactual rmse
    test_rmses = msm_regressor.get_normalised_n_step_rmses(dataset_collection.test_cf_treatment_seq)
```

```python
def get_normalised_n_step_rmses(self, dataset: Dataset, datasets_mc: List[Dataset] = None):
    logger.info(f'RMSE calculation for {dataset.subset_name}.')
    assert self.model_type == 'decoder' or self.model_type == 'multi' or self.model_type == 'g_net'
```

Thank you

Valentyn1997 commented 1 year ago

Hi! Thanks for raising this issue; this is indeed a bug. I pushed a new commit that adds ... or self.model_type == 'msm_regressor' to the assertion check.

Best, Valentyn
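The fix described in the reply can be illustrated with a minimal, hypothetical Python sketch. It does not reproduce the actual commit: `SUPPORTED_MODEL_TYPES` and `check_model_type` are illustrative names, and only the four model-type strings come from this thread. A membership test against a tuple is used here as one possible way to express the same check as the chained `or` comparisons, keeping the list of supported model types in one place.

```python
# Hypothetical sketch; names are illustrative, not taken from the repository.
# Model types for which n-step counterfactual RMSE is computed, including the
# 'msm_regressor' value that the original assertion rejected.
SUPPORTED_MODEL_TYPES = ('decoder', 'multi', 'g_net', 'msm_regressor')


def check_model_type(model_type: str) -> None:
    """Raise AssertionError if n-step RMSE is not supported for model_type."""
    assert model_type in SUPPORTED_MODEL_TYPES, (
        f"n-step RMSE is not supported for model_type={model_type!r}"
    )


if __name__ == '__main__':
    check_model_type('msm_regressor')  # passes once 'msm_regressor' is listed
    try:
        check_model_type('encoder')  # not in the tuple, so the assertion fails
    except AssertionError as err:
        print(err)
```

Note that Python strips `assert` statements when run with `-O`, so a check like this is a development-time guard rather than input validation.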