microsoft / SynapseML

Simple and Distributed Machine Learning
http://aka.ms/spark
MIT License

[BUG] The setReferenceDistribution method of DistributionBalanceMeasure does not correctly handle categories not present in source dataset #2010

Open perezbecker opened 1 year ago

perezbecker commented 1 year ago

SynapseML version

0.11.1

Describe the problem

When setting a custom reference distribution in DistributionBalanceMeasure, the reported measures are correct as long as every category in the reference distribution also appears in the dataset. However, when the reference distribution contains categories that are not present in the dataset, the measures are incorrect. This is not an unusual situation: a dataset often has no rows for categories that have a small likelihood of occurring, and that is exactly where this issue crops up.

Code to reproduce issue

Let me showcase the issue with a toy example. We will measure the distance between two distributions of a categorical feature using the Jensen-Shannon distance (JSD):

source = ['red','red','red','green','blue']
target = ['red','red','green','blue','blue']

This is the JSD computed with scipy:

from scipy.spatial import distance
import numpy as np

def jensen_shannon_distance_categorical(x_list, y_list):

    # unique values observed in x and y
    values = set(x_list + y_list)

    x_counts = np.array([x_list.count(value) for value in values])
    y_counts = np.array([y_list.count(value) for value in values])

    x_ratios = x_counts / np.sum(x_counts)  #Optional as JS-D normalizes probability vectors
    y_ratios = y_counts / np.sum(y_counts)

    # Warning: we compute the JSD with base-e logarithms here so the result can be compared with SynapseML.
    # For the JSD to be bounded between 0 and 1, base-2 logarithms are needed.
    # See this issue for details: https://github.com/microsoft/SynapseML/issues/2006
    return distance.jensenshannon(x_ratios, y_ratios)

jensen_shannon_distance_categorical(source, target)
0.1644921288538882
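
For reference, the quantity computed above can be written out as follows. This is the standard definition of the Jensen-Shannon distance with natural logarithms, which is what scipy's distance.jensenshannon returns by default; P and Q are the normalized probability vectors over the same categories, M is their pointwise mean, and any term with p_i = 0 or q_i = 0 is taken to contribute zero:

```latex
M = \tfrac{1}{2}(P + Q), \qquad
D_{\mathrm{KL}}(P \,\|\, M) = \sum_i p_i \ln\frac{p_i}{m_i}, \qquad
\mathrm{JSD}(P, Q) = \sqrt{\tfrac{1}{2} D_{\mathrm{KL}}(P \,\|\, M) + \tfrac{1}{2} D_{\mathrm{KL}}(Q \,\|\, M)}
```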

Now let's compute the same distance with DistributionBalanceMeasure, passing the target distribution through its setReferenceDistribution method:

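from pyspark.sql.types import StringType
from synapse.ml.exploratory import DistributionBalanceMeasure

# Assumes an active SparkSession named `spark`.
df = spark.createDataFrame(source, StringType()).toDF("color")
target_reference_dist = [{'red': 2/5, 'green': 1/5, 'blue': 2/5}]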

distribution_balance_measure = (
    DistributionBalanceMeasure()
    .setSensitiveCols(['color'])
    .setReferenceDistribution(target_reference_dist)
    .transform(df).select("FeatureName","DistributionBalanceMeasure.js_dist")
)

distribution_balance_measure.show(truncate=False)
+-----------+------------------+
|FeatureName|js_dist           |
+-----------+------------------+
|color      |0.1644921288538882|
+-----------+------------------+

Both answers agree, as expected.

Now we introduce a new category, yellow, into the target distribution. This category is not present in the source data.

new_target = ['red','red','green','blue','yellow']

We expect the JSD to be larger than in the original example, because the new target distribution is even more dissimilar to the source than the original target was. The scipy JSD implementation confirms this intuition:

jensen_shannon_distance_categorical(source, new_target)
0.28174895710781067
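
To make the expected handling of a missing category explicit, here is a minimal standard-library sketch that computes the distance directly from category-to-probability dictionaries, in the same shape that setReferenceDistribution accepts. The helper name js_distance is hypothetical and this is not SynapseML code; it only illustrates the behavior I would expect, namely taking the union of categories and treating a category missing on one side as probability 0.

```python
import math

def js_distance(p, q):
    """Jensen-Shannon distance (natural log) between two {category: probability} dicts.

    Hypothetical helper for illustration only; not part of SynapseML or scipy.
    """
    # Use the union of categories; a category missing on one side gets probability 0.
    categories = sorted(set(p) | set(q))
    p_total = sum(p.values())
    q_total = sum(q.values())
    divergence = 0.0
    for c in categories:
        pi = p.get(c, 0.0) / p_total
        qi = q.get(c, 0.0) / q_total
        mi = 0.5 * (pi + qi)
        # Zero-probability terms contribute nothing (0 * log 0 is taken as 0).
        if pi > 0:
            divergence += 0.5 * pi * math.log(pi / mi)
        if qi > 0:
            divergence += 0.5 * qi * math.log(qi / mi)
    # Guard against a tiny negative value caused by floating-point rounding.
    return math.sqrt(max(divergence, 0.0))

# Observed distribution of `source`, and the reference that includes 'yellow'.
observed = {'red': 3/5, 'green': 1/5, 'blue': 1/5}
reference_with_yellow = {'red': 2/5, 'green': 1/5, 'blue': 1/5, 'yellow': 1/5}

print(js_distance(observed, reference_with_yellow))  # approximately 0.2817
```

With the zero-filled 'yellow' entry on the observed side, this agrees with the scipy value above.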

Now we repeat the calculation with DistributionBalanceMeasure:

new_reference_dist=[{'red':2/5, 'green':1/5,'blue':1/5,'yellow':1/5}]

distribution_balance_measure2 = (
    DistributionBalanceMeasure()
    .setSensitiveCols(['color'])
    .setReferenceDistribution(new_reference_dist)
    .transform(df).select("FeatureName","DistributionBalanceMeasure.js_dist")
)

distribution_balance_measure2.show(truncate=False)
+-----------+------------------+
|FeatureName|js_dist           |
+-----------+------------------+
|color      |0.1003382119401399|
+-----------+------------------+

This JSD is smaller than the original value (0.1645), which is incorrect. We get exactly the same distance if we truncate the reference distribution to exclude the new category without re-normalizing the remaining probabilities:

truncated_new_reference_dist=[{'red':2/5, 'green':1/5,'blue':1/5}]

distribution_balance_measure3 = (
    DistributionBalanceMeasure()
    .setSensitiveCols(['color'])
    .setReferenceDistribution(truncated_new_reference_dist)
    .transform(df).select("FeatureName","DistributionBalanceMeasure.js_dist")
)

distribution_balance_measure3.show(truncate=False)
+-----------+------------------+
|FeatureName|js_dist           |
+-----------+------------------+
|color      |0.1003382119401399|
+-----------+------------------+
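
The reported value can also be reproduced outside Spark. The sketch below is a guess at what effectively happens, not a reading of the SynapseML source: it applies the JSD formula to the observed probabilities and to the truncated reference probabilities, which sum to 0.8, without normalizing either vector. The helper name js_distance_raw is hypothetical.

```python
import math

def js_distance_raw(p, q):
    """JS distance (natural log) over aligned lists, deliberately WITHOUT normalizing.

    Hypothetical helper for illustration only; scipy normalizes its inputs, this does not.
    """
    total = 0.0
    for pi, qi in zip(p, q):
        mi = 0.5 * (pi + qi)
        # Zero-probability terms contribute nothing (0 * log 0 is taken as 0).
        if pi > 0:
            total += 0.5 * pi * math.log(pi / mi)
        if qi > 0:
            total += 0.5 * qi * math.log(qi / mi)
    # Unnormalized inputs can push the sum slightly below zero; clamp before the root.
    return math.sqrt(max(total, 0.0))

# Order: red, green, blue. 'yellow' is dropped and the rest is left as-is.
observed = [3/5, 1/5, 1/5]
truncated_reference = [2/5, 1/5, 1/5]  # sums to 0.8, not 1

print(js_distance_raw(observed, truncated_reference))  # approximately 0.1003
```

This matches the js_dist reported by DistributionBalanceMeasure, which is consistent with the unseen category being dropped before the distance is computed.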

Other info / logs

No response

github-actions[bot] commented 1 year ago

Hey @perezbecker :wave:! Thank you so much for reporting the issue/feature request :rotating_light:. Someone from SynapseML Team will be looking to triage this issue soon. We appreciate your patience.