sdv-dev / SDV

Synthetic data generation for tabular data
https://docs.sdv.dev/sdv

Can't save model with string primary keys #1086

Closed jmlane8 closed 1 year ago

jmlane8 commented 1 year ago

Environment Details

Error Description

I am trying to create synthetic data for a tabular dataset. The primary key of this table is a string, e.g. ABCD0001, ABCD0002. I was able to train the model and generate the data, but when I tried to save it with sdv.save("tmp.pkl"), I got this error: TypeError: cannot pickle 'generator' object
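The error message itself is not specific to SDV: Python's pickle machinery cannot serialize generator objects, so any state that contains one fails to save. The following is a minimal standard-library sketch of that behavior. The `state` dictionary and its `id_source` key are hypothetical stand-ins chosen for illustration, not SDV internals.

```python
import pickle

# Hypothetical model state: nested dicts, one of which holds a generator.
state = {
    "tables": {
        "mytable": {
            "id_source": (f"ABCD{i:04d}" for i in range(1, 10000)),
        }
    }
}


def try_pickle(obj):
    """Return None if obj pickles cleanly, else the TypeError message."""
    try:
        pickle.dumps(obj)
    except TypeError as exc:
        return str(exc)
    return None


error_message = try_pickle(state)
print(error_message)  # exact wording differs between Python versions
```

The same state with the generator replaced by a plain value (for example `None` or a list) pickles without error.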

Steps to reproduce

Create a table with a string primary key, train the model, then try to save.

METADATA = {
  'tables': {
    'mytable': {
      'primary_key': 'MyID',
      'fields': {
        'MyID': {
          "name":"MyID",
          'type': 'id',
          'subtype': 'string',
          'regex': 'ABCD\d{4}' },

from sdv import SDV

sdv = SDV()
sdv.fit(METADATA, tables)
sdv.save("tmp.pkl")

npatki commented 1 year ago

Hi @jmlane8, nice to meet you! Unfortunately, I'm unable to replicate this error given the metadata and information above.

One observation: The METADATA variable you provided above seems to be missing some closing brackets. I've pasted below the definition I am using.

METADATA = {
  'tables': {
    'mytable': {
      'primary_key': 'MyID',
      'fields': {
        'MyID': {
            "name":"MyID",
            'type': 'id',
            'subtype': 'string',
            'regex': 'ABCD\d{4}'}
      }
    }
  }
}

Here is the code I ran:

import pandas as pd
from sdv import SDV

# create some fake data that matches the metadata
mytable = pd.DataFrame(data={
    'MyID': ['ABCD0001', 'ABCD0002', 'ABCD0003', 'ABCD0004', 'ABCD0005']
})

# put this data into the right format
tables = {
    'mytable': mytable
}

# model and save the SDV
sdv = SDV()
sdv.fit(METADATA, tables)
sdv.save("tmp.pkl")

The code above ran for me without any issues. Could you try running it to see if you get the same error?

I should also clarify that the METADATA should describe all the tables and columns that you are trying to model (everything in the tables variable). If it does not, you may want to update the METADATA so that it reflects the full dataset.

jmlane8 commented 1 year ago

Hi @npatki, nice to meet you too! Thank you for the code sample! I just tried running sdv.sample() before save() in the code, and got the same error. However, if I put sample after save in my own code, it works. Thank you so much!

npatki commented 1 year ago

Hi @jmlane8, got it! I can replicate your error if I add an sdv.sample() call before saving. I'm including my code and output below for reference.

We'll keep this issue open and update with progress.

Workaround: For now, this works if you save the model first. You can load it in the future and sample from it.
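One plausible reading of this behavior is that the first call to sample() creates a generator that then lives in the model's state, which would explain why saving works before sampling and fails after. The thread does not confirm where the generator comes from, so the sketch below is an assumption: `new_model`, `sample`, `save`, and `save_without_generators` are hypothetical toy functions built on the standard library, not SDV code.

```python
import itertools
import pickle
import types


def new_model():
    """Toy stand-in for a fitted model (hypothetical, not SDV internals)."""
    return {"prefix": "ABCD", "id_source": None}


def sample(model, num_rows):
    """Return num_rows IDs, lazily creating the ID generator on first use."""
    if model["id_source"] is None:
        # The generator only comes into existence when sampling first happens.
        model["id_source"] = (
            f"{model['prefix']}{i:04d}" for i in itertools.count(1)
        )
    return list(itertools.islice(model["id_source"], num_rows))


def save(model):
    """Serialize the whole model state, generator included if present."""
    return pickle.dumps(model)


def save_without_generators(model):
    """Serialize a copy of the state with generator values reset to None."""
    clean = {
        key: None if isinstance(value, types.GeneratorType) else value
        for key, value in model.items()
    }
    return pickle.dumps(clean)


model = new_model()
saved_before = save(model)      # works: no generator exists yet
first_ids = sample(model, 2)    # creates the generator

try:
    save(model)
    saved_after_ok = True
except TypeError:
    saved_after_ok = False      # fails: the state now holds a generator

# Workaround from the thread: load the earlier save, then sample from it.
restored = pickle.loads(saved_before)
restored_ids = sample(restored, 2)

# Alternative sketch: drop the generator before pickling.
restored_clean = pickle.loads(save_without_generators(model))
```

In this toy version, dropping the generator means a reloaded model starts its ID sequence again from ABCD0001. Whether that trade-off is acceptable depends on how the IDs are used downstream.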

Replication Steps

import pandas as pd
from sdv import SDV

METADATA = {
  'tables': {
    'mytable': {
      'primary_key': 'MyID',
      'fields': {
        'MyID': {
            "name":"MyID",
            'type': 'id',
            'subtype': 'string',
            'regex': 'ABCD\d{4}'}
      }
    }
  }
}

# create some fake data that matches the metadata
mytable = pd.DataFrame(data={
    'MyID': ['ABCD0001', 'ABCD0002', 'ABCD0003', 'ABCD0004', 'ABCD0005']
})

# put this data into the right format
tables = {
    'mytable': mytable
}

# model and save the SDV
sdv = SDV()
sdv.fit(METADATA, tables)
sdv.sample()
sdv.save("tmp.pkl")

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-b32ed869af7e> in <module>
----> 1 model.save('test.pkl')

12 frames
/usr/local/lib/python3.7/dist-packages/sdv/relational/base.py in save(self, path)
    195 
    196         with open(path, 'wb') as output:
--> 197             cloudpickle.dump(self, output)
    198 
    199     @classmethod

/usr/local/lib/python3.7/dist-packages/cloudpickle/cloudpickle_fast.py in dump(obj, file, protocol)
     86         compatibility with older versions of Python.
     87         """
---> 88         CloudPickler(file, protocol=protocol).dump(obj)
     89 
     90     def dumps(obj, protocol=None):

/usr/local/lib/python3.7/dist-packages/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    630     def dump(self, obj):
    631         try:
--> 632             return Pickler.dump(self, obj)
    633         except RuntimeError as e:
    634             if "recursion" in e.args[0]:

/usr/lib/python3.7/pickle.py in dump(self, obj)
    435         if self.proto >= 4:
    436             self.framer.start_framing()
--> 437         self.save(obj)
    438         self.write(STOP)
    439         self.framer.end_framing()

/usr/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    547 
    548         # Save the reduce() output and finally memoize the object
--> 549         self.save_reduce(obj=obj, *rv)
    550 
    551     def persistent_id(self, obj):

/usr/lib/python3.7/pickle.py in save_reduce(self, func, args, state, listitems, dictitems, obj)
    660 
    661         if state is not None:
--> 662             save(state)
    663             write(BUILD)
    664 

/usr/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    502         f = self.dispatch.get(t)
    503         if f is not None:
--> 504             f(self, obj) # Call unbound method with explicit self
    505             return
    506 

/usr/lib/python3.7/pickle.py in save_dict(self, obj)
    857 
    858         self.memoize(obj)
--> 859         self._batch_setitems(obj.items())
    860 
    861     dispatch[dict] = save_dict

/usr/lib/python3.7/pickle.py in _batch_setitems(self, items)
    883                 for k, v in tmp:
    884                     save(k)
--> 885                     save(v)
    886                 write(SETITEMS)
    887             elif n:

/usr/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    502         f = self.dispatch.get(t)
    503         if f is not None:
--> 504             f(self, obj) # Call unbound method with explicit self
    505             return
    506 

/usr/lib/python3.7/pickle.py in save_dict(self, obj)
    857 
    858         self.memoize(obj)
--> 859         self._batch_setitems(obj.items())
    860 
    861     dispatch[dict] = save_dict

/usr/lib/python3.7/pickle.py in _batch_setitems(self, items)
    888                 k, v = tmp[0]
    889                 save(k)
--> 890                 save(v)
    891                 write(SETITEM)
    892             # else tmp is empty, and we're done

/usr/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    522             reduce = getattr(obj, "__reduce_ex__", None)
    523             if reduce is not None:
--> 524                 rv = reduce(self.proto)
    525             else:
    526                 reduce = getattr(obj, "__reduce__", None)

TypeError: can't pickle generator objects

npatki commented 1 year ago

Good news! This bug is fixed in the new SDV 1.0 (Beta!) release.

Note that in this new release, we no longer have an overall SDV() object. The equivalent is to use the HMASynthesizer. For more information see: