opendp / smartnoise-sdk

Tools and service for differentially private processing of tabular and relational data
MIT License

Add conditional sampling for synthesizers #496

Closed: neat-web closed this issue 1 year ago

neat-web commented 2 years ago

User story

As a user of the smartnoise-synth library, I want to generate samples that satisfy certain conditions. This would enable me, for example, to generate synthetic data on the fly without storing it permanently.

Detailed description

In many cases a user generates differentially private synthetic data, saves it as a file and then runs queries against it. In some scenarios, however, it is simpler to skip that step and directly output only the samples that are needed.

For example, the Synthetic Data Vault (SDV) library has already implemented such functionality in this feature request.

Proposed solution

Most synthesizers cannot directly generate values that satisfy certain conditions. CTGAN-based synthesizers may be able to, as in the SDV implementation, but I have not looked into that. Therefore, I suggest a simple rejection sampling process: the method draws up to max_tries batches of samples and keeps only the rows that satisfy the conditions.
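
A minimal sketch of the batch rejection sampling described above, assuming the synthesizer can be wrapped as a callable that returns k rows and that the condition is already available as a Python predicate (parsing the SQL condition is a separate step). The names rejection_sample, toy_sample and batch_size are hypothetical and not part of smartnoise-synth.

```python
import random
from typing import Callable, List, Sequence, Tuple

Row = Tuple[int, ...]


def rejection_sample(
    sample_batch: Callable[[int], Sequence[Row]],
    predicate: Callable[[Row], bool],
    n: int,
    batch_size: int = 100,
    max_tries: int = 20,
) -> List[Row]:
    """Draw up to max_tries batches and keep rows that satisfy predicate (hypothetical helper)."""
    accepted: List[Row] = []
    for _ in range(max_tries):
        batch = sample_batch(batch_size)
        accepted.extend(row for row in batch if predicate(row))
        if len(accepted) >= n:
            return accepted[:n]  # drop surplus rows from the last batch
    raise ValueError(
        f"Only {len(accepted)} of {n} rows satisfied the condition "
        f"after {max_tries} batches"
    )


# Hypothetical stand-in for a fitted synthesizer: rows of (age, income).
rng = random.Random(0)


def toy_sample(k: int) -> List[Row]:
    return [(rng.randint(18, 90), rng.randint(0, 5000)) for _ in range(k)]


rows = rejection_sample(toy_sample, lambda r: r[0] < 50 and r[1] > 1000, n=50)
```

If a fraction p of synthetic rows satisfies the condition, each batch yields about batch_size * p accepted rows on average, so roughly n / (batch_size * p) batches are needed; max_tries caps the work when the condition is rare.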

The syntax of the conditions should be easy for a user to understand. SDV uses custom sdv.sampling.Condition objects, which I personally find inflexible. The pandasql library is already installed indirectly as a dependency, so I propose defining the conditions using feature-rich SQL. In contrast to SDV's approach, this also lets a user express inequalities and combine conditions with logical operators, e.g. WHERE age < 50 AND income > 1000.
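
To illustrate what a SQL condition buys over equality-only Condition objects, here is a small sketch that filters one sampled batch with a WHERE clause. It uses the standard-library sqlite3 module as a stand-in for pandasql; the function name filter_where is hypothetical. The WHERE string is interpolated directly into the query, so this is only suitable for trusted input.

```python
import sqlite3
from typing import List, Sequence


def filter_where(rows: Sequence[tuple], columns: Sequence[str], where: str) -> List[tuple]:
    """Return the rows of a batch that satisfy a SQL WHERE clause (hypothetical helper)."""
    conn = sqlite3.connect(":memory:")
    try:
        # Quote column names so that any identifier is accepted.
        col_sql = ", ".join('"' + c.replace('"', '""') + '"' for c in columns)
        conn.execute(f"CREATE TABLE batch ({col_sql})")
        placeholders = ", ".join("?" for _ in columns)
        conn.executemany(f"INSERT INTO batch VALUES ({placeholders})", rows)
        # Inequalities and AND/OR come for free from the SQL engine.
        return conn.execute(f"SELECT * FROM batch WHERE {where}").fetchall()
    finally:
        conn.close()


batch = [(34, 2500), (61, 4000), (45, 800), (29, 1200)]
kept = filter_where(batch, ["age", "income"], "age < 50 AND income > 1000")
```

Here kept holds (34, 2500) and (29, 1200), the only rows with age below 50 and income above 1000.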

I would be willing to contribute this feature in a few weeks. Feel free to propose any changes or make suggestions.

joshua-oss commented 2 years ago

That's a great idea. As you mention, some of the synthesizers support sampling from a conditional distribution, but rejection sampling would be the most general-purpose way to provide support.

It looks like the SDV calling convention is:

synth.fit(dataset)
records = synth.sample(n=n, conditions=conditions)

We could use the same parameter name, and just take a SQL string, or we could use a different parameter name to make it clear that our conditions aren't compatible. For example:

synth.fit(pums)
records = synth.sample(1000, where='age < 50 AND income > 1000')

I don't have a strong preference; open to suggestions.

One caveat is that we support arrays of tuples and numpy arrays as first-class input and output types, so it's not uncommon to have tabular data that doesn't have column names. For that case, I suppose we could allow people to pass in an optional columns parameter, or we could support SQL syntax with numeric column names using standard column name escaping rules. Either approach would work for me.
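
A sketch of the two options for unnamed columns, assuming tuple rows: accept an optional list of names, otherwise fall back to positional names "0", "1", ..., which then have to be written as escaped identifiers in the condition. The helper names column_names and quote_ident are hypothetical, and SQLite is used only to show that standard double-quote escaping makes numeric column names usable.

```python
import sqlite3
from typing import List, Optional, Sequence


def column_names(width: int, columns: Optional[Sequence[str]] = None) -> List[str]:
    """Use caller-supplied names if given, else positional names "0", "1", ..."""
    if columns is None:
        return [str(i) for i in range(width)]
    if len(columns) != width:
        raise ValueError(f"expected {width} column names, got {len(columns)}")
    return list(columns)


def quote_ident(name: str) -> str:
    """Standard SQL identifier quoting: wrap in double quotes, double any inner quotes."""
    return '"' + name.replace('"', '""') + '"'


rows = [(34, 2500), (61, 4000), (45, 800)]
names = column_names(len(rows[0]))  # no columns parameter supplied

conn = sqlite3.connect(":memory:")
conn.execute(f"CREATE TABLE t ({', '.join(quote_ident(c) for c in names)})")
conn.executemany(f"INSERT INTO t VALUES ({', '.join('?' * len(names))})", rows)
# Unquoted, 0 would be the numeric literal zero; quoted, "0" refers to the first column.
kept = conn.execute('SELECT * FROM t WHERE "0" < 50 AND "1" > 1000').fetchall()
conn.close()
```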

We actually have built-in support for SQL expression evaluation that doesn't rely on pandas or SQLite. Here is an end-to-end example of doing rejection sampling with the SmartNoise AST, which could serve as a starting point for implementing this on the sample() function:

import pandas as pd
from snsynth import Synthesizer
from snsql.sql.parse import QueryParser

pums = pd.read_csv("PUMS.csv")
colnames = list(pums.columns)
# use plain tuples (without the index) to match the internal representation that synthesizers use
pums_tuples = list(pums.itertuples(index=False, name=None))

# fit the synthesizer
synth = Synthesizer.create("mwem", epsilon=3.0, split_factor=2, verbose=True)
synth.fit(pums_tuples, preprocessor_eps=1.0)

# set the condition, n, and max_tries
condition = 'age / 5 > LOG(income)'
# condition = 'age < 50 AND income > 1000'
max_tries = 20
n = 1000

# parse the condition into an AST by wrapping it in a dummy query
dummy = "SELECT * FROM FOO WHERE " + condition
q = QueryParser().query(dummy)
cond = q.where.condition

# perform rejection sampling: up to max_tries draws for each accepted row
samples = []
for _ in range(n):
    for _ in range(max_tries):
        row = synth.sample(1)
        # bind column names to the values of the sampled row
        bindings = {c: row[0][i] for i, c in enumerate(colnames)}
        if cond.evaluate(bindings):
            samples.append(row[0])
            break
    else:
        raise ValueError("Unable to find a row that satisfies the condition")

neat-web commented 2 years ago

Firstly, thank you for your detailed feedback.

> We could use the same parameter name, and just take a SQL string, or we could use a different parameter name to make it clear that our conditions aren't compatible. For example:

Since some users probably have experience with SDV, I would prefer to make it clear that the conditions are not compatible. Maybe it is a good idea to implement the functionality in a new method (like sample_conditions() in SDV), so that we do not add more rejection-sampling-related keyword arguments to the basic sample() method?

> One caveat is that we support arrays of tuples and numpy arrays as first-class input and output types, so it's not uncommon to have tabular data that doesn't have column names. For that case, I suppose we could allow people to pass in an optional columns parameter, or we could support SQL syntax with numeric column names using standard column name escaping rules. Either approach would work for me.

I do not have a strong preference and will decide during the implementation.

> We actually have built-in support for SQL expression evaluation that doesn't rely on pandas or SQLite. Here is an end-to-end example of doing rejection sampling with the SmartNoise AST, which could serve as a starting point for implementing this on the sample() function:

Nice! I did not know about that and will use the example as a starting point.

joshua-oss commented 2 years ago

That sounds great. Doing something like sample_conditions would have the added benefit that the implementation could be done entirely on the Synthesizer base class and would be picked up by all synthesizers. I'll defer to you on this.
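
A sketch of why a base-class method is picked up by every synthesizer: sample_conditions is written once against the abstract sample() method, so each subclass only has to implement sample(). BaseSynthesizer, UniformSynth and ConstantSynth are hypothetical toy classes, not the actual snsynth Synthesizer hierarchy, and the condition is a Python predicate rather than a parsed SQL expression.

```python
import random
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple

Row = Tuple[int, ...]


class BaseSynthesizer(ABC):
    """Hypothetical base class; not the snsynth Synthesizer API."""

    @abstractmethod
    def sample(self, n: int) -> List[Row]:
        ...

    def sample_conditions(
        self,
        n: int,
        condition: Callable[[Row], bool],
        max_tries: int = 20,
        batch_size: int = 100,
    ) -> List[Row]:
        # Defined once here and inherited by every subclass that implements sample().
        kept: List[Row] = []
        for _ in range(max_tries):
            kept.extend(r for r in self.sample(batch_size) if condition(r))
            if len(kept) >= n:
                return kept[:n]
        raise ValueError("Unable to find enough rows that satisfy the condition")


class UniformSynth(BaseSynthesizer):
    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def sample(self, n: int) -> List[Row]:
        return [(self._rng.randint(18, 90), self._rng.randint(0, 5000)) for _ in range(n)]


class ConstantSynth(BaseSynthesizer):
    def sample(self, n: int) -> List[Row]:
        return [(30, 2000)] * n


young = UniformSynth().sample_conditions(10, lambda r: r[0] < 50)
fixed = ConstantSynth().sample_conditions(3, lambda r: r[1] > 1000)
```

Neither subclass defines sample_conditions, yet both support it; the same would hold for any synthesizer added later.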