sdv-dev / CTGAN

Conditional GAN for generating synthetic tabular data.

How to load this model directly to generate data after saving it #361

Closed RedBlue01 closed 4 months ago

RedBlue01 commented 4 months ago

Environment details

Problem description

This is my code. I want to train for a set number of epochs, save the model, and then load the saved model directly to generate data. However, the attempt failed with an error.

from sdv.single_table import CTGANSynthesizer
synthesizer = CTGANSynthesizer.load(
    filepath='/home/visitor/Huang/Analytical-Method/GAN/my_synthesizer_mini_e200NEW.pkl'
)
synthetic_data = synthesizer.sample(num_rows=10)
synthetic_data.to_csv('/home/visitor/Huang/Analytical-Method/GAN/synthetic_data.csv', index=False)

print(synthetic_data)
print('Done')

What I already tried

I looked through the anaconda3/envs/AM/lib/python3.10/site-packages/sdv/data_processing/data_processor.py file, but with my limited experience I could not work out how to fix the problem. This is the traceback I get:

Traceback (most recent call last):
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 761, in _sample_with_progress_bar
    sampled = self._sample_in_batches(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 692, in _sample_in_batches
    sampled_rows = self._sample_batch(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 624, in _sample_batch
    sampled, num_valid = self._sample_rows(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 563, in _sample_rows
    sampled = self._data_processor.reverse_transform(sampled)
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/data_processing/data_processor.py", line 827, in reverse_transform
    raise NotFittedError()
sdv.data_processing.errors.NotFittedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/visitor/Huang/Analytical-Method/GAN/myCTGAN_Fit_mini load.py", line 5, in <module>
    synthetic_data = synthesizer.sample(num_rows=10)
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 800, in sample
    return self._sample_with_progress_bar(
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/base.py", line 770, in _sample_with_progress_bar
    handle_sampling_error(output_file_path == TMP_FILE_NAME, output_file_path, error)
  File "/home/visitor/anaconda3/envs/AM/lib/python3.10/site-packages/sdv/single_table/utils.py", line 112, in handle_sampling_error
    raise type(sampling_error)(error_msg + '\n' + str(sampling_error))
sdv.data_processing.errors.NotFittedError: Error: Sampling terminated. Partial results are stored in a temporary file: .sample.csv.temp. This file will be overridden the next time you sample. Please rename the file if you wish to save these results.

npatki commented 4 months ago

Hi @RedBlue01, nice to meet you.

The error message seems to indicate that the synthesizer you are loading in was never fitted -- therefore, it is not possible to sample from it. Did you create the original synthesizer (saved as my_synthesizer_mini_e200NEW.pkl)? If so, could you share the code that went into creating that pkl file?

BTW instead of using the CTGAN library directly, I would highly recommend you move to the SDV library. You can access the CTGAN Synthesizer via SDV. Doing so will allow you to make use of additional features -- such as better data pre-processing, customizations such as constraints, and conditional sampling. Here is a tutorial that uses CTGAN via the SDV library.

RedBlue01 commented 4 months ago

Hi @npatki, thank you very much for responding to this question, and I'm sorry for the late reply. Here is the code I used to create the original synthesizer:

import pandas as pd

data = pd.read_csv('/home/visitor/Huang/Analytical-Method/column_123after.csv', usecols=[0, 2])

from sdv.metadata import SingleTableMetadata
metadata=SingleTableMetadata()
metadata.detect_from_dataframe(data)
python_dict = metadata.to_dict()
print(data)
print(python_dict)

from sdv.single_table import CTGANSynthesizer
synthesizer = CTGANSynthesizer(
    metadata, # required
    enforce_rounding=True,
    epochs=200,
    verbose=True
)
synthesizer.save(
    filepath='/home/visitor/Huang/Analytical-Method/GAN/my_synthesizer_e200NEW.pkl'
)

synthesizer.fit(data)
synthesizer.get_loss_values()

synthetic_data = synthesizer.sample(num_rows=10)

print(synthetic_data)
print('Done')

Thank you so much for everything you have done. I have already installed SDV with `pip install sdv`, and it is amazing work.

npatki commented 4 months ago

Hi @RedBlue01, thanks for sharing your code.

The problem is that you are saving your synthesizer before you fit it. I would recommend saving the synthesizer after you call the fit function. The fitting process is where the machine learning happens. You want the saved file to include what was learned, so saving should happen after fitting.

synthesizer.fit(data)

synthesizer.save(
    filepath='/home/visitor/Huang/Analytical-Method/GAN/my_synthesizer_e200NEW.pkl'
)

Keep in mind that when you call save, you will save the state of the synthesizer at that point of time only, as a pkl file.
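
To illustrate why the order matters, here is a minimal sketch using only the Python standard library. `ToySynthesizer` is a hypothetical stand-in written for this example, not SDV's implementation, and its `save`/`load` methods only assume that saving writes a snapshot of the object's current state to a pkl file. A file saved before `fit` reloads as an unfitted object, while a file saved after `fit` can be sampled from.

```python
import os
import pickle
import tempfile


class ToySynthesizer:
    """Hypothetical stand-in for a synthesizer (not SDV's actual class)."""

    def __init__(self):
        self.fitted = False
        self.mean = None

    def fit(self, values):
        # "Training" here simply memorizes the mean of the data.
        self.mean = sum(values) / len(values)
        self.fitted = True

    def sample(self, num_rows):
        if not self.fitted:
            raise RuntimeError("synthesizer was never fitted")
        return [self.mean] * num_rows

    def save(self, filepath):
        # Writes the state as it is right now; later changes are not included.
        with open(filepath, "wb") as handle:
            pickle.dump(self.__dict__, handle)

    @classmethod
    def load(cls, filepath):
        instance = cls()
        with open(filepath, "rb") as handle:
            instance.__dict__.update(pickle.load(handle))
        return instance


with tempfile.TemporaryDirectory() as folder:
    early_path = os.path.join(folder, "saved_before_fit.pkl")
    late_path = os.path.join(folder, "saved_after_fit.pkl")

    model = ToySynthesizer()
    model.save(early_path)       # too early: nothing has been learned yet
    model.fit([1.0, 2.0, 3.0])
    model.save(late_path)        # correct: the learned state is included

    early = ToySynthesizer.load(early_path)
    late = ToySynthesizer.load(late_path)

try:
    early.sample(2)
    early_error = None
except RuntimeError as error:
    early_error = str(error)

late_rows = late.sample(2)
print(early_error)
print(late_rows)
```

In this toy version, sampling from the early file fails and sampling from the late file succeeds, which mirrors the save-then-fit ordering problem in this thread.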

RedBlue01 commented 4 months ago

Hi @npatki, thank you so much for your help. I have finally solved this problem, which had troubled me for a long time. I never thought that it came down to the order of save and fit. The code works great. What amazing work; thank you and your team again for your work and dedication.