Hi, thank you so much for your code, this is really great work.
May I ask how to run the MSM? When I run `train_msm.py`, it requires `dataset.treatment_model`, so I set it to `"multilabel"`. When it computes `test_rmses` for `test_cf_treatment_seq`, the function `get_normalised_n_step_rmses` requires `model_type` to be `'decoder'`, `'g_net'`, or `'multi'`, but here the value of `model_type` is `'msm_regressor'`. Do I need to change this value? If so, which one should I use?
Here is the corresponding code:
```python
test_rmses = {}
if hasattr(dataset_collection, 'test_cf_treatment_seq'):  # Test n_step_counterfactual rmse
    test_rmses = msm_regressor.get_normalised_n_step_rmses(dataset_collection.test_cf_treatment_seq)
```

```python
def get_normalised_n_step_rmses(self, dataset: Dataset, datasets_mc: List[Dataset] = None):
    logger.info(f'RMSE calculation for {dataset.subset_name}.')
    assert self.model_type == 'decoder' or self.model_type == 'multi' or self.model_type == 'g_net'
```
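A minimal stand-alone reproduction of the failure, assuming the check behaves exactly as in the assertion quoted above. `RegressorStub` and `check_model_type_original` are hypothetical names for illustration only, not part of the repository:

```python
# Hypothetical reproduction of the assertion quoted above; names are illustrative only.
class RegressorStub:
    def __init__(self, model_type: str):
        self.model_type = model_type


def check_model_type_original(model) -> None:
    # Same condition as the original assertion in get_normalised_n_step_rmses
    assert model.model_type == 'decoder' or model.model_type == 'multi' or model.model_type == 'g_net'


try:
    check_model_type_original(RegressorStub('msm_regressor'))
    failed = False
except AssertionError:
    failed = True

print(failed)  # True: 'msm_regressor' is not one of the accepted model types
```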
Thank you
Hi!
Thanks for raising this issue; this is indeed a bug.
I pushed a new commit, which adds `... or self.model_type == 'msm_regressor'` to the assertion check.
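A rough sketch of what the widened check amounts to, written as a tuple membership test, which is equivalent to chaining `==` comparisons with `or`. `ALLOWED_MODEL_TYPES` and `check_model_type_patched` are hypothetical names; this is not the code from the commit itself:

```python
# Hypothetical sketch of the widened check; not the repository's actual code.
ALLOWED_MODEL_TYPES = ('decoder', 'multi', 'g_net', 'msm_regressor')


def check_model_type_patched(model_type: str) -> None:
    # Membership in the tuple is equivalent to the chained `==` ... `or` assertion
    assert model_type in ALLOWED_MODEL_TYPES, f'Unsupported model_type: {model_type!r}'


check_model_type_patched('msm_regressor')  # passes silently once 'msm_regressor' is accepted
```

Keeping the accepted values in a single tuple can make it harder to forget a model type when another one is added later.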