salesforce / OmniXAI

OmniXAI: A Library for eXplainable AI
BSD 3-Clause "New" or "Revised" License

Persistence of `TabularTransform` constructor default arguments #88

Open phantom-duck opened 1 year ago

phantom-duck commented 1 year ago

Hello! I think there is a potential issue when creating several `TabularTransform` instances (`omnixai.preprocessing.tabular.TabularTransform`).

Specifically, the `__init__` method of `TabularTransform` has default values for its three arguments, and these defaults are instances of transformation classes. Python evaluates default argument values only once, when the `def` statement is executed (for a module-level class, when the module is imported), not on every call. As a result, every `TabularTransform` created without explicit arguments shares the same transformation objects, so fitting one instance changes the state seen by all the others.
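A minimal, library-independent sketch of the mechanism follows. The classes `Encoder` and `Transform` are hypothetical stand-ins, not OmniXAI code. The point is that the default `Encoder()` is created a single time, when the `def` statement runs, and is stored in the function's `__defaults__` tuple.

```python
class Encoder:
    """Hypothetical stand-in for a transformation class such as a one-hot encoder."""

    def __init__(self):
        self.categories = None


class Transform:
    """Hypothetical stand-in for TabularTransform (not the OmniXAI implementation)."""

    # The Encoder() call below runs once, when this def statement is executed,
    # not each time Transform() is constructed.
    def __init__(self, cate_transform=Encoder()):
        self.cate_transform = cate_transform


t1 = Transform()
t2 = Transform()

# Both instances hold the very same default object...
print(t1.cate_transform is t2.cate_transform)  # True
# ...which is also the object stored in the function's defaults tuple.
print(Transform.__init__.__defaults__[0] is t1.cate_transform)  # True
```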

In practice, the shared defaults can lead to unintuitive and misleading results. For example:

```python
import pandas as pd
from omnixai.data.tabular import Tabular
from omnixai.preprocessing.tabular import TabularTransform

df1 = pd.DataFrame({"cat_feat": ["A", "B", "B", "A"], "num_feat": [1, 2, 3, 4]})
print(df1)
# output:
#   cat_feat  num_feat
# 0        A         1
# 1        B         2
# 2        B         3
# 3        A         4

tabular1 = Tabular(df1, categorical_columns=["cat_feat"])
transform1 = TabularTransform().fit(tabular1)
print(transform1.cate_shape)
# output: 2
print(transform1.cate_transform.get_feature_names())
# output: ['x0_A' 'x0_B']

df2 = pd.DataFrame({"cat_feat_1": ["A", "B", "B", "A"], "cat_feat_2": ["h", "h", "h", "l"], "num_feat": [1, 2, 3, 4]})
print(df2)
# output:
#   cat_feat_1 cat_feat_2  num_feat
# 0          A          h         1
# 1          B          h         2
# 2          B          h         3
# 3          A          l         4

tabular2 = Tabular(df2, categorical_columns=["cat_feat_1", "cat_feat_2"])
transform2 = TabularTransform().fit(tabular2)
print(transform2.cate_shape)
# output: 4
print(transform2.cate_transform.get_feature_names())
# output: ['x0_A' 'x0_B' 'x1_h' 'x1_l']

# transform1 was not touched after fitting, yet:
print(transform1.cate_shape)
# output: 2
print(transform1.cate_transform.get_feature_names())
# output: ['x0_A' 'x0_B' 'x1_h' 'x1_l']
```

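Although `transform1` was never modified after it was fitted, `transform1.cate_transform.get_feature_names()` now returns the same array as `transform2`, because both instances share the same default categorical transform. At the same time, `transform1.cate_shape` still reports 2, so the object is left in an internally inconsistent state.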
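The inconsistency above (`cate_shape` still 2 while the feature names come from the second dataset) is what one would expect if `fit` caches the shape on the `TabularTransform` instance while the fitted state lives on the shared encoder. The sketch below reproduces that pattern with hypothetical classes; it is an assumption about how the pieces interact, not OmniXAI's actual code, and it passes feature names directly instead of a dataset to keep things short.

```python
class Encoder:
    """Hypothetical encoder whose fitted state lives on the encoder instance."""

    def __init__(self):
        self.feature_names = []

    def fit(self, columns):
        # Refitting overwrites whatever state a previous fit left behind.
        self.feature_names = list(columns)
        return self


class Transform:
    """Hypothetical stand-in for TabularTransform with a shared default encoder."""

    def __init__(self, cate_transform=Encoder()):
        self.cate_transform = cate_transform
        self.cate_shape = None

    def fit(self, columns):
        self.cate_transform.fit(columns)
        # This value is cached at fit time and is not updated when the
        # shared encoder is refitted through another Transform instance.
        self.cate_shape = len(self.cate_transform.feature_names)
        return self


transform1 = Transform().fit(["x0_A", "x0_B"])
transform2 = Transform().fit(["x0_A", "x0_B", "x1_h", "x1_l"])

print(transform1.cate_shape)                    # 2 (cached before the second fit)
print(transform1.cate_transform.feature_names)  # ['x0_A', 'x0_B', 'x1_h', 'x1_l']
```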

yangwenz commented 1 year ago

Thanks for reporting this issue; we will move the default parameters into the `__init__` function in the next minor version.
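One common way to create the defaults inside `__init__` is the `None`-sentinel idiom, sketched below with hypothetical class names. The change that OmniXAI actually ships may look different.

```python
class Encoder:
    """Hypothetical stand-in for a transformation class."""

    def __init__(self):
        self.feature_names = []


class Transform:
    """Hypothetical TabularTransform-like class using None as the default sentinel."""

    def __init__(self, cate_transform=None):
        # A fresh Encoder is built on every call unless the caller supplies one.
        self.cate_transform = cate_transform if cate_transform is not None else Encoder()


t1 = Transform()
t2 = Transform()
print(t1.cate_transform is t2.cate_transform)  # False

# Sharing an encoder is still possible, but only when the caller asks for it.
shared = Encoder()
print(Transform(shared).cate_transform is shared)  # True
```

The same reasoning suggests a workaround for affected versions: construct new transformation objects and pass them explicitly to each `TabularTransform`, so that no shared default is ever used. This thread does not show the constructor's parameter names, so check the signature before relying on this.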