Closed michaelpollmann closed 1 year ago
The `DataWrapper` object saves all the summary statistics it needs about the original data to generate new data. You should be able to save and reload it via `pickle`:
```python
import wgan, pickle

data_wrapper = wgan.DataWrapper(df, continuous_vars, categorical_vars, context_vars)

with open("/path/wrapper.pkl", "wb") as f:
    pickle.dump(data_wrapper, f)

with open("/path/wrapper.pkl", "rb") as f:
    loaded_data_wrapper = pickle.load(f)
```
If it saves all the summary statistics it needs, why are we asking the user to supply a dataframe during sampling? Because of the `context_vars`, which the GAN is conditioned on and doesn't generate itself, so their values must be supplied during generation.
Based on your question, I assume you don't use any `context_vars`. In that case, the current requirement to supply a dataframe is indeed unnecessary; mainly its shape is used to infer the desired sample size. With the current package, you could e.g. save the first row of either the real or a generated dataframe:
```python
df.iloc[0:1].to_feather("df0.feather")  # save
df0 = pd.read_feather("df0.feather")    # load
```

and supply `df0.sample(desired_sample_size, replace=True)` as the second argument to `data_wrapper.apply_generator(...)`.
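Putting that workaround together, a minimal sketch (the column names, values, and `desired_sample_size` are illustrative, not from the package; the final `apply_generator` call is left as a comment because it requires a trained generator and matching wrapper):

```python
import pandas as pd

# Stand-in for the one-row template that would come from
# df.iloc[0:1].to_feather(...) / pd.read_feather(...)
df0 = pd.DataFrame({"A": [5.0], "B": [25.0]})

desired_sample_size = 200

# Super-sample the single-row template to the desired number of rows;
# only its shape (and context columns, if any) matter to apply_generator.
template = df0.sample(desired_sample_size, replace=True).reset_index(drop=True)

# df_generated = data_wrapper.apply_generator(generator, template)
```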
But I agree that that's unintuitive for users who don't use any `context_vars`. It would be nice to add an option to supply `desired_sample_size` directly, instead of the super-sampled `df0` (an MVP implementation would just save `df.iloc[0:1]` when the `data_wrapper` is initialized and reuse it during `.apply_generator`). But having two mutually exclusive arguments introduces a new source of confusion and requires some updated docs, so I'd only add it if you feel that it would be a significant net value-add.
Thank you, saving the `DataWrapper` indeed seems to work well for my first issue!

I am using `context_vars`. It seems to me that the dataframe I need to pass to `apply_generator` still has to contain the columns that are being generated (not just the `context_vars` columns). It would be great if I could supply a dataframe with just the `context_vars` in it, and the other variables would be created based e.g. on information stored in the `DataWrapper`.
The documentation for `DataWrapper` states that the input dataframe needs to include the columns to be generated:
https://ds-wgan.readthedocs.io/en/latest/api.html#data-wrapper

"df (pandas.DataFrame) – Training data frame, includes both variables to be generated, and variables to be conditioned on"
A small example suggests that I indeed need those columns:

```python
import pandas as pd
import wgan
import torch
import pickle

i = 4
df = pd.DataFrame()
for newcol in ["A", "B"]:
    df[newcol] = [i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i]
    i = i + 1

df2 = pd.DataFrame()
for newcol in ["A"]:
    df2[newcol] = [i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i, i+1, i+2, i*i, i*i*i]
    i = i + 1

#### Generation
# B | A
categorical_vars = ["B"]
context_vars = ["A"]
data_wrapper = wgan.DataWrapper(df, categorical_vars=categorical_vars, context_vars=context_vars)
spec = wgan.Specifications(data_wrapper, batch_size=4, max_epochs=4)
generator = wgan.Generator(spec)
critic = wgan.Critic(spec)

# train B | A
x, context = data_wrapper.preprocess(df)
wgan.train(generator, critic, x, context, spec)

# save data_wrapper
with open("/path/wrapper_AB.pkl", "wb") as f:
    pickle.dump(data_wrapper, f)

# load data_wrapper
with open("/path/wrapper_AB.pkl", "rb") as f:
    loaded_data_wrapper = pickle.load(f)
```
Then the following works:

```python
# generate some artificial data
df_generated = data_wrapper.apply_generator(generator, df)
df_generated
```
But using `df2` does not work:

```python
# generate some artificial data using loaded_data_wrapper and df2
df_generated2 = loaded_data_wrapper.apply_generator(generator, df2)
df_generated2
```
and returns an error:

```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [6], in <cell line: 2>()
      1 # generate some artificial data using loaded_data_wrapper and df2
----> 2 df_generated2 = loaded_data_wrapper.apply_generator(generator, df2)
      3 df_generated2

File C:\ProgramData\Miniconda3\lib\site-packages\wgan\wgan.py:143, in DataWrapper.apply_generator(self, generator, df)
    141 generator.to("cpu")
    142 original_columns = df.columns
--> 143 x, context = self.preprocess(df)
    144 x_hat = generator(context)
    145 df_hat = self.deprocess(x_hat, context)

File C:\ProgramData\Miniconda3\lib\site-packages\wgan\wgan.py:92, in DataWrapper.preprocess(self, df)
     90 x, context = [(x-m)/s for x,m,s in zip([x, context], self.means, self.stds)]
     91 if len(self.variables["categorical"]) > 0:
---> 92     categorical = torch.tensor(pd.get_dummies(df[self.variables["categorical"]], columns=self.variables["categorical"]).to_numpy())
     93     x = torch.cat([x, categorical.to(torch.float)], -1)
     94 total = torch.cat([x, context], -1)

File C:\ProgramData\Miniconda3\lib\site-packages\pandas\core\frame.py:3511, in DataFrame.__getitem__(self, key)
   3509 if is_iterator(key):
   3510     key = list(key)
-> 3511 indexer = self.columns._get_indexer_strict(key, "columns")[1]
   3513 # take() does not accept boolean indexers
   3514 if getattr(indexer, "dtype", None) == bool:

File C:\ProgramData\Miniconda3\lib\site-packages\pandas\core\indexes\base.py:5782, in Index._get_indexer_strict(self, key, axis_name)
   5779 else:
   5780     keyarr, indexer, new_indexer = self._reindex_non_unique(keyarr)
-> 5782 self._raise_if_missing(keyarr, indexer, axis_name)
   5784 keyarr = self.take(indexer)
   5785 if isinstance(key, Index):
   5786     # GH 42790 - Preserve name from an Index

File C:\ProgramData\Miniconda3\lib\site-packages\pandas\core\indexes\base.py:5842, in Index._raise_if_missing(self, key, indexer, axis_name)
   5840 if use_interval_msg:
   5841     key = list(key)
-> 5842     raise KeyError(f"None of [{key}] are in the [{axis_name}]")
   5844 not_found = list(ensure_index(key)[missing_mask.nonzero()[0]].unique())
   5845 raise KeyError(f"{not_found} not in index")

KeyError: "None of [Index(['B'], dtype='object')] are in the [columns]"
```
Using `loaded_data_wrapper` together with `df` (which contains both columns) does work:

```python
# generate some artificial data using loaded_data_wrapper and df
df_generated3 = loaded_data_wrapper.apply_generator(generator, df)
df_generated3
```
I think it would be convenient not to have to create empty columns that are to be filled by the generator, in part because I would like the code to be as reusable as possible, and needlessly hardcoding dataframe column names seems suboptimal.
(Saving and loading the `DataWrapper` rather than the data itself is already a major improvement for me, so thanks for pointing out that that suffices to solve my first issue!)
I updated the package such that `DataWrapper.apply_generator(g, df)` doesn't require `df` to contain any columns listed in `continuous_vars` or `categorical_vars` anymore. Without any `context_vars`, users can just set `df=pd.DataFrame(index=range(desired_sample_size))`.
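To illustrate the updated behavior, a minimal sketch of the dataframes one would now pass (the column name `"A"` reuses the earlier example's context variable and is illustrative; the `apply_generator` calls are commented out because they need a trained generator):

```python
import pandas as pd

desired_sample_size = 500

# Without context_vars: an empty frame whose only job is to set the row count.
df_empty = pd.DataFrame(index=range(desired_sample_size))

# With context_vars (here "A"): a frame holding just the conditioning column.
df_context = pd.DataFrame({"A": range(desired_sample_size)})

# df_generated = loaded_data_wrapper.apply_generator(generator, df_empty)
# df_generated = loaded_data_wrapper.apply_generator(generator, df_context)
```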
I didn't bump the version number, so you'll want to `pip3 install` using the `--upgrade` or `--force-reinstall` flag. Let me know if that works for you!
Thanks for working on simplifying this for me! It now works the way I would like it to work, with much shorter and slightly faster code (for simulating) than what I had before.
A minor point of concern: it does not work with a `DataWrapper` that I saved (using `pickle` as recommended above) with the previous version of the package, because it attempts to access a field (`df0`) of the `DataWrapper` object that simply didn't exist before. I don't know how many people have an old `DataWrapper` saved away somewhere, so I don't know whether this makes the update problematic, or whether it is worthwhile to put in some fallback for old `DataWrapper` objects, which would then lead to old code needing to be maintained...
Specifically, I got the following error message when trying to generate data (`apply_generator`) with the new package version using a `DataWrapper` created with the previous package version:

```
AttributeError                            Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 df_generated = data_wrappers[0].apply_generator(generators[0], df.sample(100, replace=True))
      2 df_generated

File C:\ProgramData\Miniconda3\lib\site-packages\wgan\wgan.py:145, in DataWrapper.apply_generator(self, generator, df)
    143 updated = self.variables["continuous"] + self.variables["categorical"]
    144 df = df.drop(updated, axis=1, errors="ignore").reset_index(drop=True).copy()
--> 145 df = self.df0.sample(len(df), replace=True).reset_index(drop=True).join(df)
    146 original_columns = df.columns
    147 x, context = self.preprocess(df)

AttributeError: 'DataWrapper' object has no attribute 'df0'
```
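If re-installing isn't convenient, a user-side patch might be to backfill the missing attribute on the unpickled object before calling `apply_generator`. This is a speculative sketch, assuming (from the traceback above) that `df0` is simply a one-row template of the training dataframe; verify against the updated source before relying on it. The `OldDataWrapper` class and the column names here are stand-ins, not the package's actual objects:

```python
import pandas as pd

class OldDataWrapper:  # stand-in for a DataWrapper pickled with the previous version
    pass

old_wrapper = OldDataWrapper()  # in practice: pickle.load(...) of the old file

# One row of the original (or a previously generated) dataframe.
df = pd.DataFrame({"A": [5.0], "B": [25.0]})

# Attach the field the new apply_generator expects, if it is missing.
if not hasattr(old_wrapper, "df0"):
    old_wrapper.df0 = df.iloc[0:1].reset_index(drop=True)
```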
You can always install the old package by referring to the git commit before the change:

```
pip install git+https://github.com/gsbDBI/ds-wgan.git@a65686832ea5bba27f1c6175c252769dc00b2cc3
```

If you're in an environment with the current package version and need to downgrade, you probably have to use the `--force-reinstall` flag, which is annoying because it reinstalls the dependencies as well. In theory, the `--upgrade` flag should also allow for downgrades (without reinstalling dependencies), but it doesn't seem to work for `pip install --upgrade git+...` downgrades on my end.
Anyway, there's no need to handle backwards compatibility inside the package; just install the version you need. Otherwise the requested functionality should work now, so I'll close this issue. Feel free to reopen if you run into problems.
It would be really nice if it were possible to use the generator to generate data without access to the real data. In particular, to get the scaling/centering and variable names right in the deprocess function, it would be nice to have a function in the package to save (and load) just those aspects of the data wrapper that are truly needed (variable names and types; means, standard deviations, and values for categorical variables?). It seems like that should in principle be possible, so that the user does not need to keep having access to (and loading) the real data whenever they want to generate artificial data.

(Simulating one very large data set once after training, while the real data is still loaded, isn't always a great option, particularly for very large samples.)
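To make the request concrete, here is a rough sketch of extracting such a minimal "recipe" once, while the real data is still available. All field names are illustrative, not the actual `DataWrapper` attributes, and whether these statistics suffice for `deprocess` would need to be checked against the package source:

```python
import json
import pandas as pd

# Toy training frame; "A" continuous, "B" categorical (illustrative).
df = pd.DataFrame({"A": [1.0, 2.0, 3.0], "B": ["x", "y", "x"]})
continuous_vars, categorical_vars = ["A"], ["B"]

# Everything deprocessing plausibly needs: variable lists, per-column
# scaling statistics, and the observed category values.
recipe = {
    "continuous": continuous_vars,
    "categorical": categorical_vars,
    "means": {c: df[c].mean() for c in continuous_vars},
    "stds": {c: df[c].std() for c in continuous_vars},
    "categories": {c: sorted(df[c].unique().tolist()) for c in categorical_vars},
}

# Save as plain JSON, so later generation runs need neither the real
# data nor pickle compatibility with a particular package version.
with open("wrapper_recipe.json", "w") as f:
    json.dump(recipe, f)
```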