microsoft / responsible-ai-toolbox-mitigations

Python library for implementing Responsible AI mitigations.
https://responsible-ai-toolbox-mitigations.readthedocs.io/en/latest/
MIT License

Case3.ipynb: Invalid column name `sick-euthyroid` #31

Closed: morrissharp closed this issue 2 years ago

morrissharp commented 2 years ago

When running the CTGAN section of case3.ipynb (under "5 - Synthetic Data"), I receive the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
c:\Users\morrissharp\Repos\responsible-ai-toolbox-mitigations\notebooks\dataprocessing\case_study\case3.ipynb Cell 45 in <cell line: 11>()
      8 synth.fit()
     10 conditions = {label_col:1}    # create more of the undersampled class
---> 11 syn_train_x, syn_train_y = synth.transform(X=train_x_sel, y=train_y, n_samples=200, conditions=conditions)
     13 syn_train_y.value_counts()

File c:\users\morrissharp\repos\responsible-ai-toolbox-mitigations\raimitigations\dataprocessing\sampler\synthesizer.py:570, in Synthesizer.transform(self, df, X, y, n_samples, conditions, strategy)
    568 if n_samples is not None:
    569     print(df.columns, conditions)
--> 570     samples = self.model.sample(n_samples, conditions=conditions)
    571 else:
    572     samples = self._generate_samples_strategy(df, strategy)

File c:\Users\morrissharp\Miniconda3\envs\rai\lib\site-packages\sdv\tabular\base.py:451, in BaseTabularModel.sample(self, num_rows, max_retries, max_rows_multiplier, conditions, float_rtol, graceful_reject_sampling)
    449 for column in conditions.columns:
    450     if column not in self._metadata.get_fields():
--> 451         raise ValueError(f'Invalid column name `{column}`')
    453 try:
    454     transformed_conditions = self._metadata.transform(conditions, on_missing_column='drop')

ValueError: Invalid column name `sick-euthyroid`

I am not sure what is going on. `sick-euthyroid` appears to be the name of the pandas Series that is passed in as `train_y`, and it is also the key used in `conditions` (via `label_col`).
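
For reference, the check that raises in the traceback (`column not in self._metadata.get_fields()`) rejects any condition key the fitted model does not know about. The snippet below is only a minimal standalone sketch of that kind of membership test, not the SDV or raimitigations implementation: the helper name `invalid_condition_columns` and the feature column names are hypothetical, and it does not establish why the label column is missing from the fitted metadata.

```python
def invalid_condition_columns(fitted_columns, conditions):
    """Return the condition keys that are not among the fitted columns."""
    known = set(fitted_columns)
    return [name for name in conditions if name not in known]


# Hypothetical feature names, for illustration only.
feature_columns = ["age", "TSH", "T3"]
label_col = "sick-euthyroid"
conditions = {label_col: 1}

# If the label was not among the fitted columns, the condition key is rejected.
missing = invalid_condition_columns(feature_columns, conditions)

# If the label was among the fitted columns, the condition key is accepted.
accepted = invalid_condition_columns(feature_columns + [label_col], conditions)

print(missing)
print(accepted)
```

Comparing the keys of `conditions` against the columns the synthesizer was actually fitted on in this way would show whether `sick-euthyroid` was dropped before fitting or only at sampling time.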