vanderschaarlab / synthcity

A library for generating and evaluating synthetic tabular data for privacy, fairness and data augmentation.
https://www.vanderschaar-lab.com/
Apache License 2.0

Error when attempting to Benchmark the model for survival analysis #282

Open uclrmhigid opened 2 months ago

uclrmhigid commented 2 months ago

## Description

I am attempting to generate synthetic data conditional on ethnicity for my survival data. I am able to generate the data, but I get an error regarding time_to_event when attempting to benchmark the model: "ValueError: The time_to_event_column contains 1 values less than or equal to zero. Please remove them." is raised when running Benchmarks.evaluate.

## How to Reproduce

    import sys
    import warnings

    import numpy as np

    import synthcity.logger as log  # assumed import for the `log` used below
    from synthcity.benchmark import Benchmarks
    from synthcity.plugins import Plugins
    from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
    from synthcity.utils.serialization import load, load_from_file, save, save_to_file

    # Set up logging and filter warnings
    log.add(sink=sys.stderr, level="INFO")
    warnings.filterwarnings("ignore")

    loader = SurvivalAnalysisDataLoader(
        subset_df,
        target_column="event_cmp",
        time_to_event_column="tstop",
    )

    syn_model = Plugins().get("survival_gan")
    cond = subset_df["Race=Asian or Pacific Islander"]
    syn_model.fit(loader, cond=cond)

    count = 10
    syn_model.generate(count=count, cond=np.ones(count)).dataframe()

    buff = save(syn_model)
    type(buff)
    reloaded = load(buff)
    reloaded.name()

    score = Benchmarks.evaluate(
        [(f"test{model}", model, {}) for model in ["adsgan", "survival_gan", "survae"]],
        loader,
        synthetic_size=1000,
        repeats=2,
        task_type="survival_analysis",
    )

## Expected Behavior

A score for the quality of each plugin.

## Screenshots

Get error:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    Input In [83], in <cell line: 5>()
          1 # synthcity absolute
          2 # Can't get to work
          3 from synthcity.benchmark import Benchmarks
    ----> 5 score = Benchmarks.evaluate(
          6     [(f"test{model}", model, {}) for model in ["adsgan", "survival_gan", "survae"]],
          7     loader,
          8     synthetic_size=1000,
          9     repeats=1,
         10     task_type="survival_analysis",
         11 )

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\benchmark\__init__.py:288, in Benchmarks.evaluate(tests, X, X_test, metrics, repeats, synthetic_size, synthetic_constraints, synthetic_cache, synthetic_reuse_if_exists, augmented_reuse_if_exists, task_type, workspace, augmentation_rule, strict_augmentation, ad_hoc_augment_vals, use_metric_cache, **generate_kwargs)
        286 else:
        287     X_augmented = None
    --> 288 evaluation = Metrics.evaluate(
        289     X_test if X_test is not None else X.test(),
        290     X_syn,
        291     X.train(),
        292     X_ref_syn,
        293     X_augmented,
        294     metrics=metrics,
        295     task_type=task_type,
        296     workspace=workspace,
        297     use_cache=use_metric_cache,
        298 )
        300 mean_score = evaluation["mean"].to_dict()
        301 errors = evaluation["errors"].to_dict()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\metrics\eval.py:204, in Metrics.evaluate(X_gt, X_syn, X_train, X_ref_syn, X_augmented, reduction, n_histogram_bins, metrics, task_type, random_state, workspace, use_cache)
        201     metrics = Metrics.list()
        203 X_gt, _ = X_gt.encode()
    --> 204 X_syn, _ = X_syn.encode()
        206 if X_train:
        207     X_train, _ = X_train.encode()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\plugins\core\dataloader.py:244, in DataLoader.encode(self, encoders)
        242     encoded[col] = encoder.transform(encoded[col]).values
        243     encoders[col] = encoder
    --> 244 return self.from_info(encoded, self.info()), encoders

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\plugins\core\dataloader.py:647, in SurvivalAnalysisDataLoader.from_info(data, info)
        644 if not isinstance(data, pd.DataFrame):
        645     raise ValueError(f"Invalid data type {type(data)}")
    --> 647 return SurvivalAnalysisDataLoader(
        648     data,
        649     target_column=info["target_column"],
        650     time_to_event_column=info["time_to_event_column"],
        651     sensitive_features=info["sensitive_features"],
        652     important_features=info["important_features"],
        653     time_horizons=info["time_horizons"],
        654     fairness_column=info["fairness_column"],
        655 )

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

    File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\plugins\core\dataloader.py:531, in SurvivalAnalysisDataLoader.__init__(self, data, time_to_event_column, target_column, time_horizons, sensitive_features, important_features, fairness_column, random_state, train_size, **kwargs)
        529 row_diff = data.shape[0] - data_filtered.shape[0]
        530 if row_diff > 0:
    --> 531     raise ValueError(
        532         f"The time_to_event_column contains {row_diff} values less than or equal to zero. Please remove them."
        533     )
        535 if len(time_horizons) == 0:
        536     time_horizons = np.linspace(T.min(), T.max(), num=5)[1:-1].tolist()

    ValueError: The time_to_event_column contains 1 values less than or equal to zero. Please remove them.

## System Information

Python 3.10.11

## Additional Context

I'm not sure if it's relevant, but even if I set `subset_df = subset_df[subset_df['tstop'] > 0]` before running any of this code, I still get the same error.
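A quick sanity check for where the offending value comes from (a sketch, assuming the `subset_df`, `syn_model` and column names from the snippet above; the traceback shows the error is raised while the *generated* data is re-wrapped in a SurvivalAnalysisDataLoader inside Metrics.evaluate, so the synthetic output is worth inspecting too):

    # Training data: should print 0 after the pre-filtering described above
    print((subset_df["tstop"] <= 0).sum())

    # Synthetic data: the error comes from encoding X_syn, so a non-positive
    # generated duration would trigger it even with clean training data
    syn_df = syn_model.generate(count=1000, cond=np.ones(1000)).dataframe()
    print((syn_df["tstop"] <= 0).sum())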

robsdavis commented 2 weeks ago

Hi @uclrmhigid, Thanks for submitting this issue.

The error you are seeing comes from this part of the code:

        T = data[time_to_event_column]
        data_filtered = data[T > 0]
        row_diff = data.shape[0] - data_filtered.shape[0]
        if row_diff > 0:
            raise ValueError(
                f"The time_to_event_column contains {row_diff} values less than or equal to zero. Please remove them."
            )

Does your time-to-event column contain any values that are not greater than 0? If yes, then this is the expected behaviour and you will need to remove or re-label those datapoints (see the sketch below).
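For reference, a minimal sketch of that clean-up (assuming the `subset_df`, `tstop` and `event_cmp` names from your snippet):

    # Option 1: drop rows whose time-to-event is not strictly positive
    clean_df = subset_df[subset_df["tstop"] > 0].reset_index(drop=True)

    # Option 2: re-label non-positive durations to a small positive value instead of dropping them
    # clean_df = subset_df.copy()
    # clean_df.loc[clean_df["tstop"] <= 0, "tstop"] = 1e-3

    loader = SurvivalAnalysisDataLoader(
        clean_df,
        target_column="event_cmp",
        time_to_event_column="tstop",
    )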

If no, are you able to share your data somehow so that I can re-create this? If not, can you create a toy dataset that causes this issue?
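If a toy dataset is easier, something along these lines (entirely made-up values, reusing your column names) should be enough to trigger the same check directly in the loader:

    import numpy as np
    import pandas as pd

    from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader

    # 100 synthetic survival rows with a single non-positive duration
    rng = np.random.default_rng(0)
    toy = pd.DataFrame(
        {
            "feature": rng.normal(size=100),
            "event_cmp": rng.integers(0, 2, size=100),
            "tstop": rng.exponential(scale=10.0, size=100),
        }
    )
    toy.loc[0, "tstop"] = 0.0  # the offending value

    # Raises: ValueError: The time_to_event_column contains 1 values less than or equal to zero. ...
    loader = SurvivalAnalysisDataLoader(
        toy,
        target_column="event_cmp",
        time_to_event_column="tstop",
    )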