Closed: neat-web closed this issue 1 year ago
That's a great idea. As you mention, some of the synthesizers support sampling from a conditional distribution, but rejection sampling would be the most general-purpose way to provide support.
It looks like the SDV calling convention is:

synth.fit(dataset)
records = synth.sample(n=n, conditions=conditions)
We could use the same parameter name, and just take a SQL string, or we could use a different parameter name to make it clear that our conditions aren't compatible. For example:
synth.fit(pums)
records = synth.sample(1000, where='age < 50 AND income > 1000')
I don't have a strong preference; open to suggestions.
One caveat is that we support arrays of tuples and numpy arrays as first-class input and output types, so it's not uncommon to have tabular data that doesn't have column names. For that case, I suppose we could allow people to pass in an optional `columns` parameter, or we could support SQL syntax with numeric column names using standard column name escaping rules. Either approach would work for me.
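To make the two options concrete, here is a minimal sketch of the positional-name fallback for unnamed tabular data; the function name `resolve_colnames` and the `columns` parameter are hypothetical illustrations, not existing snsynth API:

```python
def resolve_colnames(data, columns=None):
    """Return column names for a list-of-tuples dataset.

    If the caller passes an explicit `columns` list, use it; otherwise
    fall back to stringified positions, which SQL conditions could still
    reference via standard quoted identifiers, e.g. WHERE "0" < 50.
    """
    width = len(data[0])
    if columns is not None:
        if len(columns) != width:
            raise ValueError("columns length must match row width")
        return list(columns)
    return [str(i) for i in range(width)]

rows = [(27, 5000), (61, 200)]
print(resolve_colnames(rows))                      # positional fallback
print(resolve_colnames(rows, ["age", "income"]))   # explicit names
```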
We actually have built-in support for SQL expression evaluation that doesn't rely on pandas or SQLite. Here is an end-to-end example of doing rejection sampling with the SmartNoise AST, which could serve as a starting point for implementing this on the `sample()` function:
import pandas as pd

pums = pd.read_csv("PUMS.csv")
colnames = list(pums.columns)

# use tuples to match the internal representation that the synthesizers use
pums_tuples = [tuple(t[1:]) for t in pums.itertuples()]

# fit the synthesizer
from snsynth import Synthesizer
synth = Synthesizer.create("mwem", epsilon=3.0, split_factor=2, verbose=True)
synth.fit(pums_tuples, preprocessor_eps=1.0)

# set the condition, n, and max_tries
condition = 'age / 5 > LOG(income)'
# condition = 'age < 50 AND income > 1000'
max_tries = 20
n = 1000

# parse the condition into an AST
dummy = "SELECT * FROM FOO WHERE " + condition
from snsql.sql.parse import QueryParser
q = QueryParser().query(dummy)
cond = q.where.condition

# perform rejection sampling
for _ in range(n):
    for _ in range(max_tries):
        row = synth.sample(1)
        bindings = {c: row[0][i] for i, c in enumerate(colnames)}
        if cond.evaluate(bindings):
            print(row)
            break
    else:
        raise ValueError("Unable to find a row that satisfies the condition")
Firstly, thank you for your detailed feedback.
> We could use the same parameter name, and just take a SQL string, or we could use a different parameter name to make it clear that our conditions aren't compatible. For example:
I assume some users have experience with SDV, so I would prefer to make it clear that the conditions are not compatible. Maybe it would be a good idea to implement the functionality in a new method (like `sample_conditions()` in SDV) so as not to introduce more rejection-sampling-related keyword arguments in the basic `sample()` method?
> One caveat is that we support arrays of tuples and numpy arrays as first-class input and output types, so it's not uncommon to have tabular data that doesn't have column names. For that case, I suppose we could allow people to pass in an optional `columns` parameter, or we could support SQL syntax with numeric column names using standard column name escaping rules. Either approach would work for me.
I do not have a strong preference and will decide upon that during the implementation.
> We actually have built-in support for SQL expression evaluation that doesn't rely on pandas or SQLite. Here is an end-to-end example of doing rejection sampling with the SmartNoise AST, which could serve as a starting point for implementing this on the `sample()` function:
Nice! I did not know about that and will use the example as a starting point.
That sounds great. Doing something like `sample_conditions` would have the added benefit that the implementation could be done entirely on the `Synthesizer` base class and would be picked up by all synthesizers. I'll defer to you on this.
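For illustration, here is a hypothetical sketch of how a base-class `sample_conditions` could look; `SynthesizerBase`, `ToySynth`, and the `predicate` argument are made up for this example and are not part of the snsynth API:

```python
class SynthesizerBase:
    """Sketch of a shared base class: concrete synthesizers only
    implement sample(); the rejection logic is inherited."""

    def sample(self, n):
        raise NotImplementedError

    def sample_conditions(self, n, predicate, max_tries=20):
        """Rejection-sample n rows for which predicate(row) is true."""
        out = []
        tries = 0
        while len(out) < n:
            if tries >= max_tries:
                raise ValueError("exceeded max_tries")
            tries += 1
            # draw a batch and keep only the rows that pass the predicate
            out.extend(r for r in self.sample(n) if predicate(r))
        return out[:n]

class ToySynth(SynthesizerBase):
    """Stand-in synthesizer that cycles through fixed rows."""
    def __init__(self, rows):
        self.rows = rows
    def sample(self, n):
        return [self.rows[i % len(self.rows)] for i in range(n)]

synth = ToySynth([(30, 2000), (70, 100), (45, 1500)])
picked = synth.sample_conditions(4, lambda r: r[0] < 50 and r[1] > 1000)
print(picked)
```

Sampling whole batches of size `n` per try, rather than one row at a time, keeps the number of model calls bounded while amortizing the rejection cost.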
User story
As a user of the smartnoise-synth library, I want to generate samples that satisfy certain conditions. This enables me, e.g., to generate synthetic data on the fly without permanent data storage.
Detailed description
In many cases a user generates differentially private synthetic data, saves it as a file and then performs queries. However, for some scenarios, it is beneficial to simplify the process and just directly output specific samples.
For example, the Synthetic Data Vault (SDV) library already implemented such functionality in this feature request.
Proposed solution
Most synthesizers cannot directly generate values that satisfy certain conditions. Possibly CTGAN-based synthesizers can do that, as in the SDV implementation, but I did not look into that. Therefore, I suggest using a simple rejection sampling process where the method samples up to `max_tries` batches and keeps the samples that satisfy the conditions.
The syntax of the conditions should be easy for a user to understand. SDV uses custom `sdv.sampling.Condition` objects, which I personally find inflexible. The `pandasql` library is already indirectly installed as a requirement. I propose defining the conditions as a feature-rich SQL statement. In contrast to SDV's approach, this also enables a user to define conditions that check for inequality and combine them using logical operators, e.g. `WHERE age < 50 AND income > 1000`.
I would be willing to contribute this feature in a few weeks. Feel free to propose any changes or make suggestions.
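As a proof of concept for the batched, SQL-evaluated rejection sampling proposed above, here is a self-contained sketch that uses the stdlib `sqlite3` module in place of `pandasql`; the function name `sample_where` and its signature are illustrative only, not an existing smartnoise-synth API:

```python
import sqlite3
from contextlib import closing

def sample_where(sample_batch, colnames, condition, n, max_tries=20):
    """Draw batches via sample_batch(n) until n rows satisfy `condition`."""
    cols = ", ".join(f'"{c}"' for c in colnames)   # quote identifiers
    placeholders = ", ".join("?" for _ in colnames)
    accepted = []
    with closing(sqlite3.connect(":memory:")) as db:
        db.execute(f"CREATE TABLE batch ({cols})")
        for _ in range(max_tries):
            db.execute("DELETE FROM batch")        # fresh candidate batch
            db.executemany(f"INSERT INTO batch VALUES ({placeholders})",
                           sample_batch(n))
            # let SQL evaluate the full WHERE clause against the batch
            accepted += db.execute(
                f"SELECT * FROM batch WHERE {condition}").fetchall()
            if len(accepted) >= n:
                return accepted[:n]
    raise ValueError("Unable to find enough rows that satisfy the condition")

# toy deterministic "synthesizer" standing in for synth.sample
batch = [(30, 2000), (70, 100), (45, 1500), (25, 3000)]
rows = sample_where(lambda k: batch, ["age", "income"],
                    "age < 50 AND income > 1000", n=3)
print(rows)
```

Quoting the column identifiers also covers the numeric-column-name fallback discussed earlier, since SQL accepts quoted names like `"0"`.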