jpmml / jpmml-sklearn

Java library and command-line application for converting Scikit-Learn pipelines to PMML
GNU Affero General Public License v3.0

custom transformer not supported #127

Closed: agrimabahl closed this issue 4 years ago

agrimabahl commented 4 years ago

I am getting the following error for a custom transformer that I wrote to multiply one-hot-encoded columns by another column:

```
java.lang.IllegalArgumentException: The value object (Python class ml_analytics.estimators.scikit_statsmodel_ols.OneHotTransformer) is not a supported Transformer
	at org.jpmml.sklearn.CastFunction.apply(CastFunction.java:50)
	at sklearn_pandas.DataFrameMapper.getTransformerList(DataFrameMapper.java:169)
	at sklearn_pandas.DataFrameMapper.initializeFeatures(DataFrameMapper.java:71)
	at sklearn.Initializer.encodeFeatures(Initializer.java:41)
	at sklearn.Transformer.updateAndEncodeFeatures(Transformer.java:118)
	at sklearn.pipeline.FeatureUnion.encodeFeatures(FeatureUnion.java:45)
	at sklearn.Transformer.updateAndEncodeFeatures(Transformer.java:118)
	at sklearn.Composite.encodeFeatures(Composite.java:129)
	at sklearn2pmml.pipeline.PMMLPipeline.encodePMML(PMMLPipeline.java:208)
	at org.jpmml.sklearn.Main.run(Main.java:145)
	at org.jpmml.sklearn.Main.main(Main.java:94)
Caused by: java.lang.ClassCastException: Cannot cast net.razorvine.pickle.objects.ClassDict to sklearn.Transformer
	at java.lang.Class.cast(Class.java:3369)
	at org.jpmml.sklearn.CastFunction.apply(CastFunction.java:48)
	... 10 more
```

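Code for the transformer:

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelBinarizer

class OneHotTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, numeric_columns, string_columns):
        # Store constructor arguments under their own names so get_params()/clone() work
        self.numeric_columns = numeric_columns
        self.string_columns = string_columns

    def fit(self, X, y=None):
        # Learn the category levels once, from the training data
        self.binarizer_ = LabelBinarizer(sparse_output=True).fit(X[:, 0])
        return self

    def transform(self, X, y=None):
        # One-hot encode the categorical column (column 0) using the fitted levels
        cat_features = self.binarizer_.transform(X[:, 0])
        # Scale every indicator column by the numeric column (column 1)
        num = X[:, 1].astype(np.float64)
        p = cat_features.toarray() * num[:, np.newaxis]
        return csr_matrix(p)
```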
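To make the intended semantics explicit, which is what any converter would have to reproduce, here is a minimal NumPy sketch of the transformation. It assumes the category values and the numeric values arrive as two 1-D arrays and that the list of category levels is known in advance; `interact_one_hot` is a hypothetical helper for illustration, not part of scikit-learn or sklearn2pmml.

```python
import numpy as np

def interact_one_hot(categories, numbers, levels):
    """Return one column per category level: the number where category == level, else 0."""
    categories = np.asarray(categories)
    numbers = np.asarray(numbers, dtype=np.float64)
    # Boolean indicator matrix of shape (n_samples, n_levels), i.e. the one-hot encoding
    indicators = categories[:, np.newaxis] == np.asarray(levels)[np.newaxis, :]
    # Broadcasting multiplies every indicator column by the numeric column
    return indicators * numbers[:, np.newaxis]

X_cat = np.array(["a", "b", "a", "c"])
X_num = np.array([1.5, 2.0, 3.0, 4.0])
print(interact_one_hot(X_cat, X_num, levels=["a", "b", "c"]))
```

Each output row is non-zero only in the column of its own category, and that entry equals the row's numeric value. In other words, every output column is simply "number if category equals this level, else 0".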

Any suggestions on how I can create PMML for this transformer?

vruusmann commented 4 years ago

Exact duplicate of https://github.com/jpmml/sklearn2pmml-plugin/issues/8

You need to develop a Java side for your custom transformer. For starters, see what the Java side looks like for the regular LabelBinarizer and OneHotEncoder transformer classes.
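Whatever the Java side looks like, it ultimately has to emit PMML that reproduces the transformation. One possible encoding (an assumption, not confirmed to be how the JPMML-SkLearn converters do it) is one `DerivedField` per category level that multiplies the numeric field by a `NormDiscrete` indicator, which evaluates to 1 when the field equals the given value and 0 otherwise. The Python sketch below only builds that XML fragment with the standard library to show its shape. The field names `cat` and `num` and the helper `interaction_derived_fields` are hypothetical, and this is not a substitute for a Java converter class.

```python
import xml.etree.ElementTree as ET

def interaction_derived_fields(cat_field, num_field, levels):
    """Build one PMML DerivedField per level, computing num_field * (cat_field == level)."""
    fields = []
    for level in levels:
        derived = ET.Element("DerivedField", name=f"{num_field}_x_{cat_field}_{level}",
                             optype="continuous", dataType="double")
        product = ET.SubElement(derived, "Apply", function="*")
        ET.SubElement(product, "FieldRef", field=num_field)
        # NormDiscrete yields 1 when cat_field equals value, else 0 (the one-hot indicator)
        ET.SubElement(product, "NormDiscrete", field=cat_field, value=level)
        fields.append(derived)
    return fields

for element in interaction_derived_fields("cat", "num", ["a", "b"]):
    print(ET.tostring(element, encoding="unicode"))
```

Using `NormDiscrete` keeps each derived field a direct translation of "one-hot indicator times number"; an `if`/`equal` expression would be an equivalent alternative.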