gsbDBI / ds-wgan

Design of Simulations using WGAN
MIT License

Feature request: generate data (after training) without access to real data (only the relevant summary statistics) #12

Closed. michaelpollmann closed this issue 1 year ago

michaelpollmann commented 1 year ago

It would be really nice if it were possible to use the generator to generate data without access to the real data. In particular, to get the scaling/centering and variable names right in the deprocess function, it would help if the package had a function to save (and load) just those aspects of the data wrapper that are truly needed (variable names and types, means/standard deviations, values of categorical variables). In principle this should be possible, so that the user does not need to keep access to (or keep loading) the real data whenever they want to generate artificial data.

(Simulating one very large data set once after training, while the real data is still loaded, isn't always a great option, in particular when very large samples are needed.)

Jonas-Metzger commented 1 year ago

The DataWrapper object saves all the summary statistics it needs about the original data to generate new data. You should be able to save and reload it via pickle:

import wgan, pickle

# The wrapper stores variable names/types and the summary statistics
# (means, standard deviations, categorical values) needed to deprocess.
data_wrapper = wgan.DataWrapper(df, continuous_vars, categorical_vars, context_vars)

# Save the fitted wrapper to disk...
with open("/path/wrapper.pkl", "wb") as f:
    pickle.dump(data_wrapper, f)

# ...and reload it later, without the original data.
with open("/path/wrapper.pkl", "rb") as f:
    loaded_data_wrapper = pickle.load(f)

If it saves all the summary statistics it needs, why are we asking the user to supply a dataframe during sampling? Because of the context_vars: the GAN is conditioned on them and doesn't generate them itself, so their values must be supplied at generation time.

Based on your question, I assume you don't use any context_vars. In that case, the current requirement to supply a dataframe is indeed unnecessary: mainly its shape is used to infer the desired sample size. With the current package, you could e.g. save the first row of either the real or a generated dataframe:

import pandas as pd

df.iloc[0:1].to_feather("df0.feather")  # save a one-row template
df0 = pd.read_feather("df0.feather")    # load it later, without the real data

and supply df0.sample(desired_sample_size, replace=True) as the second argument to data_wrapper.apply_generator(...).
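Concretely, that workaround looks like this (a sketch; desired_sample_size is a placeholder, and generator and data_wrapper are assumed to be the trained objects from above):

import pandas as pd

df0 = pd.read_feather("df0.feather")  # the saved one-row template
desired_sample_size = 10000           # placeholder: however many rows you want
template = df0.sample(desired_sample_size, replace=True).reset_index(drop=True)
df_generated = data_wrapper.apply_generator(generator, template)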

But I agree that that's unintuitive for users who don't use any context_vars. It would be nice to add an option to supply desired_sample_size directly, instead of the super-sampled df0 (an MVP implementation would just save df.iloc[0:1] when the data_wrapper is initialized and reuse it during .apply_generator; see the sketch below). But having two mutually exclusive arguments introduces a new source of confusion and requires updated docs, so I'd only add it if you feel it would be a significant net value-add.
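To illustrate the MVP idea as a standalone sketch (not the package's actual code; WrapperSketch and its methods are hypothetical names):

import pandas as pd

class WrapperSketch:
    def __init__(self, df):
        # Store a one-row template at init, so the real data isn't needed later.
        self.df0 = df.iloc[0:1]

    def template(self, sample_size):
        # Super-sample the stored row up to the desired length.
        return self.df0.sample(sample_size, replace=True).reset_index(drop=True)

df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
print(WrapperSketch(df).template(5))  # 5 copies of the template row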

michaelpollmann commented 1 year ago

Thank you, saving the DataWrapper indeed seems to work well for my first issue!

I am using context_vars. It seems to me that the dataframe I pass to apply_generator still has to contain the columns that are being generated, not just the context_vars columns. It would be great if I could supply a dataframe with just the context_vars in it, with the other variables created based e.g. on information stored in the DataWrapper. The documentation for DataWrapper indeed states that the input dataframe needs to include the columns to be generated: https://ds-wgan.readthedocs.io/en/latest/api.html#data-wrapper "df (pandas.DataFrame) – Training data frame, includes both variables to be generated, and variables to be conditioned on"

A small example suggests that I indeed need those columns:

import pandas as pd
import wgan
import torch
import pickle

# df contains both the variable to be generated (B) and the context variable (A)
i = 4
df = pd.DataFrame()
for newcol in ["A", "B"]:
    df[newcol] = [i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i]
    i = i + 1

# df2 contains only the context variable (A)
df2 = pd.DataFrame()
for newcol in ["A"]:
    df2[newcol] = [i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i]
    i = i + 1

#### Generation
# B | A
categorical_vars = ["B"]
context_vars = ["A"]

data_wrapper = wgan.DataWrapper(df, categorical_vars=categorical_vars, context_vars=context_vars)

spec = wgan.Specifications(data_wrapper, batch_size=4, max_epochs=4)
generator = wgan.Generator(spec)
critic = wgan.Critic(spec)

# train B | A
x, context = data_wrapper.preprocess(df)
wgan.train(generator, critic, x, context, spec)

# save data_wrapper
with open("/path/wrapper_AB.pkl", "wb") as f: 
    pickle.dump(data_wrapper, f)

# load data wrapper
with open("/path/wrapper_AB.pkl", "rb") as f: 
    loaded_data_wrapper = pickle.load(f)

Then the following works:

# generate some artificial data
df_generated = data_wrapper.apply_generator(generator, df)
df_generated

But using df2 does not work:

# generate some artificial data using loaded_data_wrapper and df2
df_generated2 = loaded_data_wrapper.apply_generator(generator, df2)
df_generated2

It returns the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [6], in <cell line: 2>()
      1 # generate some artificial data using loaded_data_wrapper and df2
----> 2 df_generated2 = loaded_data_wrapper.apply_generator(generator, df2)
      3 df_generated2

File C:\ProgramData\Miniconda3\lib\site-packages\wgan\wgan.py:143, in DataWrapper.apply_generator(self, generator, df)
    141 generator.to("cpu")
    142 original_columns = df.columns
--> 143 x, context = self.preprocess(df)
    144 x_hat = generator(context)
    145 df_hat = self.deprocess(x_hat, context)

File C:\ProgramData\Miniconda3\lib\site-packages\wgan\wgan.py:92, in DataWrapper.preprocess(self, df)
     90 x, context = [(x-m)/s for x,m,s in zip([x, context], self.means, self.stds)]
     91 if len(self.variables["categorical"]) > 0:
---> 92     categorical = torch.tensor(pd.get_dummies(df[self.variables["categorical"]], columns=self.variables["categorical"]).to_numpy())
     93     x = torch.cat([x, categorical.to(torch.float)], -1)
     94 total = torch.cat([x, context], -1)

File C:\ProgramData\Miniconda3\lib\site-packages\pandas\core\frame.py:3511, in DataFrame.__getitem__(self, key)
   3509     if is_iterator(key):
   3510         key = list(key)
-> 3511     indexer = self.columns._get_indexer_strict(key, "columns")[1]
   3513 # take() does not accept boolean indexers
   3514 if getattr(indexer, "dtype", None) == bool:

File C:\ProgramData\Miniconda3\lib\site-packages\pandas\core\indexes\base.py:5782, in Index._get_indexer_strict(self, key, axis_name)
   5779 else:
   5780     keyarr, indexer, new_indexer = self._reindex_non_unique(keyarr)
-> 5782 self._raise_if_missing(keyarr, indexer, axis_name)
   5784 keyarr = self.take(indexer)
   5785 if isinstance(key, Index):
   5786     # GH 42790 - Preserve name from an Index

File C:\ProgramData\Miniconda3\lib\site-packages\pandas\core\indexes\base.py:5842, in Index._raise_if_missing(self, key, indexer, axis_name)
   5840     if use_interval_msg:
   5841         key = list(key)
-> 5842     raise KeyError(f"None of [{key}] are in the [{axis_name}]")
   5844 not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
   5845 raise KeyError(f"{not_found} not in index")

KeyError: "None of [Index(['B'], dtype='object')] are in the [columns]"

Using loaded_data_wrapper together with df (which contains both columns) does work:

# generate some artificial data using loaded_data_wrapper and df
df_generated3 = loaded_data_wrapper.apply_generator(generator, df)
df_generated3

I think it would be convenient not to need to create empty columns that the generator is going to fill in, in part because I would like the code to be as reusable as possible; needlessly hardcoding dataframe column names seems suboptimal.

(Saving and loading the DataWrapper rather than the data itself is already a major improvement for me, so thanks for pointing out that that suffices to solve my first issue!)

Jonas-Metzger commented 1 year ago

I updated the package so that DataWrapper.apply_generator(g, df) no longer requires df to contain any of the columns listed in continuous_vars or categorical_vars. Without any context_vars, users can just set df=pd.DataFrame(index=range(desired_sample_size)).
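In the example above, that means the call that previously raised the KeyError should now work with a context-only dataframe (a sketch reusing the names from your snippet):

# df2 contains only the context variable A; B is generated by the model.
df_generated2 = loaded_data_wrapper.apply_generator(generator, df2)

# Without context_vars, only the desired number of rows matters:
# df_blank = pd.DataFrame(index=range(1000))
# df_generated = data_wrapper.apply_generator(generator, df_blank)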

I didn't bump the version number, so you'll want to pip3 install with the --upgrade or --force-reinstall flag. Let me know if that works for you!

michaelpollmann commented 1 year ago

Thanks for working on simplifying this for me! It now works the way I would like it to work, with much shorter and slightly faster code (for simulating) than what I had before.

A minor point of concern: it does not work with a DataWrapper that I saved (using pickle, as recommended above) under the previous version of the package, because the new code attempts to access a field (df0) of the DataWrapper object that simply didn't exist before. I don't know how many people have an old DataWrapper saved away somewhere, so I can't judge whether this makes the update problematic, or whether it's worth adding a fallback for old DataWrapper objects, which would then mean old code needs to be maintained...

Specifically, I got the following error message when trying to generate data (apply_generator) with the new package version using a DataWrapper created with the previous package version:

AttributeError                            Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 df_generated = data_wrappers[0].apply_generator(generators[0], df.sample(100, replace=True))
      2 df_generated

File C:\ProgramData\Miniconda3\lib\site-packages\wgan\wgan.py:145, in DataWrapper.apply_generator(self, generator, df)
    143 updated = self.variables["continuous"] + self.variables["categorical"]
    144 df = df.drop(updated, axis=1, errors="ignore").reset_index(drop=True).copy()
--> 145 df = self.df0.sample(len(df), replace=True).reset_index(drop=True).join(df)
    146 original_columns = df.columns
    147 x, context = self.preprocess(df)

AttributeError: 'DataWrapper' object has no attribute 'df0'
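For anyone stuck with an old pickle, one possible manual patch is to attach the missing attribute by hand (a hedged sketch: it guesses from the traceback above that df0 is a one-row template of the generated columns, which is unverified, and assumes the original training dataframe df is available once more):

import pickle

with open("/path/wrapper_AB.pkl", "rb") as f:
    old_wrapper = pickle.load(f)

# Build the attribute the new code expects from the wrapper's own variable
# lists (visible in the traceback) plus one row of the original data.
gen_cols = old_wrapper.variables["continuous"] + old_wrapper.variables["categorical"]
old_wrapper.df0 = df[gen_cols].iloc[0:1]

with open("/path/wrapper_AB.pkl", "wb") as f:
    pickle.dump(old_wrapper, f)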

Jonas-Metzger commented 1 year ago

You can always install the old package version by referring to the git commit before the change:

pip install git+https://github.com/gsbDBI/ds-wgan.git@a65686832ea5bba27f1c6175c252769dc00b2cc3

If you're in an environment with the current package version and need to downgrade, you probably have to use the --force-reinstall flag, which is annoying because it reinstalls the dependencies as well. In theory, the --upgrade flag should also allow downgrades (without reinstalling dependencies), but it doesn't seem to work for pip install --upgrade git+... downgrades on my end.
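Combining the two, a downgrade to the pre-change commit would look something like:

pip install --force-reinstall git+https://github.com/gsbDBI/ds-wgan.git@a65686832ea5bba27f1c6175c252769dc00b2cc3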

Anyways, there's no need to handle backwards compatibility inside the package; just install the version you need. Otherwise, the requested functionality should work now, so I'll close this issue. Feel free to reopen if you run into problems.