sdv-dev / CTGAN

Conditional GAN for generating synthetic tabular data.

Cannot save `CTGANSynthesizer` after sampling (`TypeError`) #270

Closed npatki closed 1 year ago

npatki commented 1 year ago

Environment Details

Error Description

When using the `CTGANSynthesizer` from the new SDV 1.0 branch, I'm unable to save the trained synthesizer after sampling from it. Calling `save` raises a `TypeError`.

Steps to reproduce

from sdv.datasets.demo import download_demo
from sdv.single_table import CTGANSynthesizer

real_data, metadata = download_demo(
    modality='single_table',
    dataset_name='fake_hotel_guests'
)

synthesizer = CTGANSynthesizer(metadata)
synthesizer.fit(real_data)
synthetic_data = synthesizer.sample(num_rows=500)
synthesizer.save('my_synthesizer.pkl')

Stack Trace

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-70c571761350> in <module>
----> 1 synthesizer.save('my_synthesizer.pkl')
      2 
      3 # in the future
      4 synthesizer = CTGANSynthesizer.load('my_synthesizer.pkl')

2 frames
/usr/local/lib/python3.8/dist-packages/sdv/single_table/base.py in save(self, filepath)
   1075         """
   1076         with open(filepath, 'wb') as output:
-> 1077             cloudpickle.dump(self, output)
   1078 
   1079     @classmethod

/usr/local/lib/python3.8/dist-packages/cloudpickle/cloudpickle_fast.py in dump(obj, file, protocol, buffer_callback)
     53         compatibility with older versions of Python.
     54         """
---> 55         CloudPickler(
     56             file, protocol=protocol, buffer_callback=buffer_callback
     57         ).dump(obj)

/usr/local/lib/python3.8/dist-packages/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    630     def dump(self, obj):
    631         try:
--> 632             return Pickler.dump(self, obj)
    633         except RuntimeError as e:
    634             if "recursion" in e.args[0]:

TypeError: cannot pickle 'torch._C.Generator' object
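
For context, the final line of the trace indicates that the pickler reached an attribute it cannot serialize. One common way to handle this kind of failure is to drop the unpicklable attribute in `__getstate__` and recreate it in `__setstate__`. The sketch below is only an illustration of that pattern, not the fix adopted in SDV or CTGAN: `HypotheticalSynthesizer`, `_random_state`, and `roundtrip` are made-up names, and a `threading.Lock` stands in for the `torch._C.Generator` object because the standard `pickle` module rejects it with a similar `TypeError`.

```python
import pickle
import threading


class HypotheticalSynthesizer:
    """Stand-in for an object that holds one unpicklable attribute."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]
        # Assumption: threading.Lock plays the role of torch._C.Generator.
        # pickle refuses to serialize it and raises a TypeError.
        self._random_state = threading.Lock()

    def __getstate__(self):
        # Copy the instance dict and blank out the attribute pickle cannot handle.
        state = self.__dict__.copy()
        state['_random_state'] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Recreate the dropped attribute once the object has been loaded.
        self._random_state = threading.Lock()


def roundtrip(obj):
    """Serialize and immediately deserialize an object in memory."""
    return pickle.loads(pickle.dumps(obj))
```

With this pattern the object can be saved and loaded, but the recreated attribute starts fresh. For a random generator that would mean its internal state is not preserved across the save, which may or may not be acceptable.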