jpmml / jpmml-sparkml

Java library and command-line application for converting Apache Spark ML pipelines to PMML
GNU Affero General Public License v3.0

Custom Estimator to add to JPMML-SPARKML #28

Closed: fatihtekin closed this issue 7 years ago

fatihtekin commented 7 years ago

Hi

I do have a custom estimator that merges rare categorical values into one value, 'RARE', so that I can group all the rare labels together. I would like to know if it is possible, and how I can add my custom model converter as you did for the standard Spark ML features.

To give an example, my custom estimator handles rare values in categorical columns. So, if there are 1000 categories and only 30 of them are used most of the time, the remaining 970 categories will be marked as RARE. In my model I only save the rare labels. If you need it, I can paste the code itself as well.

Even if I manage it, I am not sure whether JPMML-Evaluator will be able to run it.

vruusmann commented 7 years ago

I do have a custom estimator that merges rare categorical values into one value, 'RARE', so that I can group all the rare labels together.

Basically, you want to perform mapping between discrete values - map "popular" values back to themselves, and all "unpopular" values to some default value.

The PMML specification provides the MapValues element exactly for this purpose:

<MapValues name="simplified_color" defaultValue="rare" outputColumn="outputValue">
  <FieldColumnPair field="color" column="inputValue"/>
  <InlineTable>
    <row>
      <inputValue>red</inputValue>
      <outputValue>red</outputValue>
    </row>
    <row>
      <inputValue>yellow</inputValue>
      <outputValue>yellow</outputValue>
    </row>
    <row>
      <inputValue>green</inputValue>
      <outputValue>green</outputValue>
    </row>
  </InlineTable>
</MapValues>

The above transformation would keep color values "red", "yellow" and "green" as-is, and change all other color values to "rare" (note the MapValues@defaultValue attribute).
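As a rough illustration of the lookup semantics, the following plain-Java sketch emulates what a PMML engine does with such a MapValues element:

- A matching row yields its outputValue.
- A non-matching value yields defaultValue.
- A missing value yields mapMissingTo, or stays missing if that attribute is absent.

The class and method names are hypothetical. This is not JPMML-Evaluator code.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical emulation of PMML MapValues semantics; not a JPMML API.
class MapValuesSketch {

    // table: inputValue -> outputValue, i.e. the rows of the InlineTable
    static String evaluate(Map<String, String> table, String input, String defaultValue, String mapMissingTo) {
        if (input == null) {
            // Missing input: result is mapMissingTo (null models "still missing")
            return mapMissingTo;
        }
        String output = table.get(input);
        // No matching row: fall back to MapValues@defaultValue
        return (output != null) ? output : defaultValue;
    }

    // Identity rows for the three "popular" colors from the example above
    static Map<String, String> colorTable() {
        Map<String, String> table = new LinkedHashMap<>();
        for (String color : new String[]{"red", "yellow", "green"}) {
            table.put(color, color);
        }
        return table;
    }

    public static void main(String[] args) {
        Map<String, String> table = colorTable();
        System.out.println(evaluate(table, "red", "rare", null));    // red
        System.out.println(evaluate(table, "purple", "rare", null)); // rare
    }
}
```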

I would like to know if it is possible, and how I can add my custom model converter as you did for the standard Spark ML features.

Here's a code example about generating a MapValues element-based transformation: https://github.com/jpmml/jpmml-sparkml/blob/master/src/main/java/org/jpmml/sparkml/feature/VectorIndexerModelConverter.java

The org.dmg.pmml.InlineTable element is rather tricky to generate in Java, because you need to work with low-level W3C DOM APIs at some point.
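To show what "low-level W3C DOM" means here, this sketch builds a single InlineTable row using nothing but the JDK's DOM API. The converter code later in this thread uses JPMML-Converter's DOMUtil.createRow for the same job. Some parts are assumptions:

- The helper names are hypothetical.
- The PMML 4.3 namespace URI is assumed. Use whatever version your document declares.

```java
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

// Hypothetical helper: one InlineTable <row> built with the plain JDK DOM API.
class InlineTableRowSketch {

    // Assumed PMML version; match the namespace of your actual PMML document
    static final String PMML_NS = "http://www.dmg.org/PMML-4_3";

    static Document newDocument() {
        try {
            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
            factory.setNamespaceAware(true);
            return factory.newDocumentBuilder().newDocument();
        } catch (ParserConfigurationException e) {
            throw new IllegalStateException(e);
        }
    }

    // Produces <row><col1>v1</col1><col2>v2</col2>...</row>
    static Element createRow(Document document, String[] columns, String[] values) {
        if (columns.length != values.length) {
            throw new IllegalArgumentException("columns and values differ in length");
        }
        Element row = document.createElementNS(PMML_NS, "row");
        for (int i = 0; i < columns.length; i++) {
            // Each cell is a child element named after its column
            Element cell = document.createElementNS(PMML_NS, columns[i]);
            cell.setTextContent(values[i]);
            row.appendChild(cell);
        }
        return row;
    }

    public static void main(String[] args) {
        Element row = createRow(newDocument(), new String[]{"inputValue", "outputValue"}, new String[]{"red", "red"});
        System.out.println(row.getChildNodes().getLength()); // 2
    }
}
```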

Even if I manage it, I am not sure whether JPMML-Evaluator will be able to run it.

JPMML-Evaluator is able to run all PMML documents that conform to PMML 3.X and 4.X specifications.

fatihtekin commented 7 years ago

This is a good suggestion; I will try that. An unrelated question: is it normal that some of the categorical columns are missing when I generate the PMML model?

fatihtekin commented 7 years ago

OK, I got it, sorry. The columns not used for mining get cleaned up by DataDictionaryCleaner. I will keep this issue open until I finish my custom converter, as I may have more questions.

vruusmann commented 7 years ago

Some more code examples - here's the LabelEncoder transformer from Scikit-Learn, which maps category values to category indexes: https://github.com/jpmml/jpmml-sklearn/blob/master/src/main/java/sklearn/preprocessing/LabelEncoder.java

If your data column contains missing values, and you'd like to map them to the default category (or some special category) as well, then don't forget to specify the MapValues@mapMissingTo attribute.

Anyway, I would recommend taking the following steps:

  1. Export your Apache Spark ML pipeline into a PMML document.
  2. Modify this PMML document manually - define the MapValues element, and replace all the color field invocations with the simplified_color field invocations.
  3. Run this modified PMML document using the org.jpmml.evaluator.EvaluationExample command-line application from the JPMML-Evaluator project. This is the fastest way to ensure that your PMML changes are structurally valid, and produce desired results.
  4. When step 3 is complete, implement the PMML transformation as org.jpmml.sparkml.FeatureConverter subclass.

fatihtekin commented 7 years ago

Can I assume that, since I set setHandleInvalid("keep") in StringIndexer, this case is already handled, given that I use StringIndexer after my RareMerger?

vruusmann commented 7 years ago

I use StringIndexer after my RareMerger

After RareMerger, there will be only "valid and popular" values left - "red", "yellow", "green" and "rare".

There is no need to specify StringIndexer#handleInvalid property.

fatihtekin commented 7 years ago

I have tried the evaluator, but I am getting an exception when I try to generate the PMML model:

Exception in thread "main" java.lang.IllegalArgumentException: Expected 7 features, got 6 features
    at org.jpmml.sparkml.ModelConverter.encodeSchema(ModelConverter.java:147)
    at org.jpmml.sparkml.ModelConverter.registerModel(ModelConverter.java:161)
    at org.jpmml.sparkml.ConverterUtil.toPMML(ConverterUtil.java:76)
    at org.apache.spark.ml.pmml.PmmlTries$.main(PmmlTries.scala:55)
    at org.apache.spark.ml.pmml.PmmlTries.main(PmmlTries.scala)

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SQLContext
import org.jpmml.sparkml.ConverterUtil

val conf = new SparkConf().setAppName("OneHotEncoderExample").setMaster("local[*]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  (1, 3, 1),
  (2, 3, 0),
  (3, 5, 1),
  (4, 5, 1),
  (5, 6, 0),
  (6, 6, 0),
  (7, 6, 0),
  (8, 999999, 1),
  (9, 8999, 0),
  (10, 89994343, 1)
)).toDF("id", "category", "label")
val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex").setHandleInvalid("keep")
val encoder = new OneHotEncoder().setInputCol("categoryIndex").setOutputCol("categoryVec").setDropLast(false)
val assembler = new VectorAssembler().setInputCols(Array("categoryVec")).setOutputCol("features")
val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler, lr))
val model = pipeline.fit(df)
model.transform(df).show(10, false)
// Fails with "Expected 7 features, got 6 features"
val pmml = ConverterUtil.toPMML(df.schema, model)

vruusmann commented 7 years ago

Exception in thread "main" java.lang.IllegalArgumentException: Expected 7 features, got 6 features

Very interesting - the example Apache Spark ML pipeline appears to generate one extra "shadow" feature. This behavior must be triggered by some configuration option, either by StringIndexer#setHandleInvalid(String) or OneHotEncoder#setDropLast(boolean).

Will investigate. Just to clarify, what is your exact Apache Spark version?

fatihtekin commented 7 years ago

Sure: Spark is 2.2.0 and Scala is 2.11. If I comment out setHandleInvalid, then it throws the exception below. By the way, I do need setHandleInvalid.

I am using the following dependencies:

"org.jpmml" % "jpmml-sparkml-xgboost" % "1.0-SNAPSHOT"
"org.jpmml" % "jpmml-xgboost" % "1.2-SNAPSHOT"
"org.jpmml" % "pmml-evaluator" % "1.3.8"
"org.jpmml" % "jpmml-sparkml" % "1.2.1"

Exception in thread "main" java.lang.NoSuchMethodError: org.jpmml.converter.ModelUtil.createMiningSchema(Lorg/jpmml/converter/Schema;)Lorg/dmg/pmml/MiningSchema;
    at org.jpmml.sparkml.model.LinearRegressionModelConverter.encodeModel(LinearRegressionModelConverter.java:40)
    at org.jpmml.sparkml.model.LinearRegressionModelConverter.encodeModel(LinearRegressionModelConverter.java:30)
    at org.jpmml.sparkml.ModelConverter.registerModel(ModelConverter.java:167)
    at org.jpmml.sparkml.ConverterUtil.toPMML(ConverterUtil.java:76)
    at org.apache.spark.ml.pmml.PmmlTries$.main(PmmlTries.scala:55)
    at org.apache.spark.ml.pmml.PmmlTries.main(PmmlTries.scala)

vruusmann commented 7 years ago

The culprit is StringIndexer#setHandleInvalid("keep"), which causes a special "catch-all-invalids" feature to be appended to the feature list.
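The arithmetic behind "Expected 7 features, got 6 features" can be modelled in a few lines of plain Java. The model rests on two assumptions:

- StringIndexer orders labels by descending frequency.
- "keep" assigns every unseen value the extra index labels.length.

The example column has 6 distinct categories, so OneHotEncoder with dropLast=false produces 7 slots. The names are hypothetical; this is not Spark's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical emulation of StringIndexer(handleInvalid="keep") + OneHotEncoder(dropLast=false).
class KeepInvalidSketch {

    // Distinct values, most frequent first (ties are left in arbitrary order)
    static List<String> fitLabels(List<String> column) {
        Map<String, Integer> counts = new HashMap<>();
        for (String value : column) {
            counts.merge(value, 1, Integer::sum);
        }
        List<String> labels = new ArrayList<>(counts.keySet());
        labels.sort((a, b) -> Integer.compare(counts.get(b), counts.get(a)));
        return labels;
    }

    // Seen values get their label position; unseen ones share index labels.size()
    static int index(List<String> labels, String value) {
        int position = labels.indexOf(value);
        return (position >= 0) ? position : labels.size();
    }

    // One-hot width with dropLast=false: one slot per label plus the catch-all slot
    static int oneHotSize(List<String> labels) {
        return labels.size() + 1;
    }

    public static void main(String[] args) {
        List<String> column = Arrays.asList("3", "3", "5", "5", "6", "6", "6", "999999", "8999", "89994343");
        List<String> labels = fitLabels(column);
        System.out.println(labels.size());        // 6
        System.out.println(oneHotSize(labels));   // 7
        System.out.println(index(labels, "399")); // 6
    }
}
```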

The "keep" invalid value handler appears to be an Apache Spark 2.2.X feature. Earlier Apache Spark versions (e.g. 2.0.X and 2.1.X) will not let you use it:

java.lang.IllegalArgumentException: strIdx_b78325d25068 parameter handleInvalid given invalid value keep.

As explained in https://github.com/jpmml/jpmml-sparkml/issues/28#issuecomment-321789670, there is no need to specify an invalid value handler after you've explicitly categorized features as "popular" and "rare" using the RareMerger transformer.

Anyway, I intend to make the JPMML-SparkML library smarter about the StringIndexer#handleInvalid property. At a minimum, it will throw a more relevant and informative exception.

vruusmann commented 7 years ago

Exception in thread "main" java.lang.NoSuchMethodError: org.jpmml.converter.ModelUtil.createMiningSchema(Lorg/jpmml/converter/Schema;)Lorg/dmg/pmml/MiningSchema;

You have a classpath conflict - Apache Spark contains JPMML-Model library version 1.2.15, which is "shadowing" the latest JPMML-Model 1.3.X.

Please configure your application classpath as specified in JPMML-SparkML README file: https://github.com/jpmml/jpmml-sparkml#library

Personally, I would suggest deleting the offending JPMML-Model library JAR files from the Apache Spark installation (as detailed in section "Modifying Apache Spark installation").
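One generic way to confirm such a classpath conflict is to ask the JVM where a class was actually loaded from. This is a plain JVM diagnostic, not something prescribed in the thread.

Run it with the same classpath as the failing application. If a class such as org.dmg.pmml.PMML resolves to a JAR inside the Apache Spark installation, the old JPMML-Model copy is winning. The class name ClassOrigin is hypothetical.

```java
import java.security.CodeSource;

// Debugging aid: prints which JAR (or directory) a class is actually loaded from.
class ClassOrigin {

    static String locate(String className) {
        try {
            // initialize=false: only load the class, do not run its static initializers
            Class<?> clazz = Class.forName(className, false, ClassOrigin.class.getClassLoader());
            CodeSource codeSource = clazz.getProtectionDomain().getCodeSource();
            if (codeSource == null || codeSource.getLocation() == null) {
                // JDK bootstrap classes have no CodeSource
                return "<bootstrap>";
            }
            return codeSource.getLocation().toString();
        } catch (ClassNotFoundException e) {
            return "<not found>";
        }
    }

    public static void main(String[] args) {
        // e.g. java ClassOrigin org.dmg.pmml.PMML org.jpmml.converter.ModelUtil
        for (String className : args) {
            System.out.println(className + " -> " + locate(className));
        }
    }
}
```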

vruusmann commented 7 years ago

Exception in thread "main" java.lang.IllegalArgumentException: Expected 7 features, got 6 features

The name of the "catch-all-invalids" pseudo-category is __unknown. The fix is available in JPMML-SparkML version 1.3.2 (and newer).

In PMML, the corresponding transformation looks like this:

<DerivedField name="handleInvalid(category)" optype="categorical" dataType="string">
    <Apply function="if">
        <Apply function="isIn">
            <FieldRef field="category"/>
            <Constant>6</Constant>
            <Constant>5</Constant>
            <Constant>3</Constant>
            <Constant>8999</Constant>
            <Constant>89994343</Constant>
            <Constant>999999</Constant>
        </Apply>
        <FieldRef field="category"/>
        <Constant>__unknown</Constant>
    </Apply>
</DerivedField>

You could use exactly the same pattern for the RareMerger transformer - simply replace __unknown with rare.
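In plain Java terms, the DerivedField above is a set-membership test with a fallback constant. The RareMerger variant only swaps the constant and uses the popular labels as the known set.

The sketch below uses hypothetical names and deliberately leaves missing values aside. They are a separate concern, handled by mapMissingTo.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

// Hypothetical emulation of if(isIn(field, known...), field, replacement); not JPMML code.
class IsInSketch {

    static String handle(Set<String> known, String value, String replacement) {
        // Missing values are out of scope for this sketch
        Objects.requireNonNull(value, "missing values are not modelled by this sketch");
        return known.contains(value) ? value : replacement;
    }

    public static void main(String[] args) {
        Set<String> known = new HashSet<>(Arrays.asList("6", "5", "3", "8999", "89994343", "999999"));
        System.out.println(handle(known, "5", "__unknown"));   // 5
        System.out.println(handle(known, "399", "__unknown")); // __unknown
        System.out.println(handle(known, "399", "rare"));      // rare
    }
}
```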

fatihtekin commented 7 years ago

It is awesome how quickly you have added the functionality. Thanks, I really appreciate it. I assume your implementation of StringIndexerModelConverter might throw an IllegalArgumentException in case setHandleInvalid("skip") is set.

I have attached my RareMerger, as I realized that I had not explained the null value handling. Perhaps code speaks better than my words; please let me know if I need to explain more. I need to transform all rare labels to 'RARE', while popular labels stay as they are. If training has seen null values, then they get translated into the category level 'NULL'; if not, they go to __unknown.

Implementation: RareMerger.txt

Test cases (in the same style as the Spark test suites for its ML features): RareMergerSuite.txt

Unfortunately, I am getting the exception below:

Exception in thread "main" java.lang.NumberFormatException: For input string: "__unknown"
    at java.lang.NumberFormatException.forInputString(NumberFormatException.java:65)
    at java.lang.Long.parseLong(Long.java:589)
    at java.lang.Long.parseLong(Long.java:631)
    at org.jpmml.evaluator.TypeUtil.parseInteger(TypeUtil.java:121)
    at org.jpmml.evaluator.TypeUtil.parse(TypeUtil.java:85)
    at org.jpmml.evaluator.TypeUtil.parseOrCast(TypeUtil.java:69)
    at org.jpmml.evaluator.FieldValueUtil.create(FieldValueUtil.java:455)
    at org.jpmml.evaluator.FieldValueUtil.refine(FieldValueUtil.java:512)
    at org.jpmml.evaluator.FieldValueUtil.refine(FieldValueUtil.java:481)
    at org.jpmml.evaluator.ExpressionUtil.evaluate(ExpressionUtil.java:64)
    at org.jpmml.evaluator.ModelEvaluationContext.resolve(ModelEvaluationContext.java:133)
    at org.jpmml.evaluator.EvaluationContext.evaluate(EvaluationContext.java:64)
    at org.jpmml.evaluator.regression.RegressionModelEvaluator.evaluateRegressionTable(RegressionModelEvaluator.java:317)
    at org.jpmml.evaluator.regression.RegressionModelEvaluator.evaluateRegression(RegressionModelEvaluator.java:128)
    at org.jpmml.evaluator.regression.RegressionModelEvaluator.evaluate(RegressionModelEvaluator.java:99)
    at org.jpmml.evaluator.ModelEvaluator.evaluate(ModelEvaluator.java:384)
    at org.apache.spark.ml.pmml.PmmlTries$.main(PmmlTries.scala:81)
    at org.apache.spark.ml.pmml.PmmlTries.main(PmmlTries.scala)    

Testing code:

import java.util

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SQLContext
import org.dmg.pmml.{DataType, FieldName}
import org.jpmml.evaluator.{CategoricalValue, ContinuousValue, FieldValue, ModelEvaluatorFactory}
import org.jpmml.sparkml.ConverterUtil

val conf = new SparkConf().setAppName("OneHotEncoderExample").setMaster("local[*]")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  (1, 3, 1),
  (2, 3, 0),
  (3, 5, 1),
  (4, 5, 1),
  (5, 6, 0),
  (6, 6, 0),
  (7, 6, 0),
  (8, 999999, 1),
  (9, 8999, 0),
  (10, 89994343, 1)
)).toDF("id", "category", "label")
val indexer = new StringIndexer().setInputCol("category").setOutputCol("categoryIndex").setHandleInvalid("keep")
val encoder = new OneHotEncoder().setInputCol("categoryIndex").setOutputCol("categoryVec").setDropLast(false)
val assembler = new VectorAssembler().setInputCols(Array("categoryVec", "id")).setOutputCol("features")
val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler, lr))
val model = pipeline.fit(df)
val pmml = ConverterUtil.toPMML(df.schema, model)
val modelEvaluatorFactory = ModelEvaluatorFactory.newInstance
val evaluator = modelEvaluatorFactory.newModelEvaluator(pmml)
println(new String(ConverterUtil.toPMMLByteArray(df.schema, model), "UTF-8"))
val arguments = new util.LinkedHashMap[FieldName, FieldValue]()
arguments.put(new FieldName("id"), ContinuousValue.create(DataType.INTEGER, 14))
// 399 was not seen during training, so it should fall into the "keep" bucket
arguments.put(new FieldName("category"), CategoricalValue.create(DataType.INTEGER, 399))
val results = evaluator.evaluate(arguments)
println(results)

I think you need to add "__unknown" to the label list and handle it as another category level:

StringIndexerModel transformer = getTransformer();
Feature feature = encoder.getOnlyFeature(transformer.getInputCol());
List<String> values = new ArrayList<>(Arrays.asList(transformer.labels()));
// TODO: the line below should only be used if handleInvalid "keep" is chosen
values.add("__unknown");
DataField dataField = encoder.toCategorical(feature.getName(), values);
return Collections.<Feature>singletonList(new CategoricalFeature(encoder, dataField));

vruusmann commented 7 years ago

I assume your implementation of StringIndexerModelConverter might throw an IllegalArgumentException in case setHandleInvalid("skip") is set.

PMML does not have "skip" functionality. If a PMML engine is asked to score a data record, then the scoring either 1) succeeds or 2) fails with some sort of exception. The "skip" option would mean that the PMML engine neither succeeds nor fails, but just consumes the data record.

Unfortunately, I am getting below exception. Exception in thread "main" java.lang.NumberFormatException: For input string: "__unknown"

The constant __unknown is a string value. If you're going to do StringIndexer#setHandleInvalid("keep"), then you should make sure that your column is of string data type.

It's an Apache Spark design decision (see the source code of the StringIndexer transformer). A possible workaround would be to replace __unknown with some other constant for numeric values (e.g. -999).
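The stack trace ends in TypeUtil.parseInteger calling Long.parseLong. The string constant __unknown is being parsed into the integer data type of the category field, which cannot succeed.

This plain-Java sketch reproduces that parse and shows the numeric-sentinel alternative. The names are hypothetical, and the sentinel -999 is only safe if no real category uses that value.

```java
import java.util.Set;

// Why "__unknown" breaks an integer-typed field, and the numeric-sentinel workaround.
class SentinelSketch {

    // Hypothetical sentinel; must not collide with any real category value
    static final long UNKNOWN_SENTINEL = -999L;

    // Mirrors the failing step in the stack trace (Long.parseLong on the constant)
    static boolean parsesAsInteger(String value) {
        try {
            Long.parseLong(value);
            return true;
        } catch (NumberFormatException e) {
            return false;
        }
    }

    // Integer-typed equivalent of the if/isIn "handleInvalid" transformation
    static long handleInvalid(Set<Long> known, long value) {
        return known.contains(value) ? value : UNKNOWN_SENTINEL;
    }

    public static void main(String[] args) {
        System.out.println(parsesAsInteger("__unknown")); // false
        System.out.println(parsesAsInteger("-999"));      // true
    }
}
```

The other option mentioned above, casting the column to a string type before StringIndexer, avoids the sentinel entirely.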

I need to transform all rare labels to 'RARE', while popular labels stay as they are. If training has seen null values, then they get translated into the category level 'NULL'; if not, they go to __unknown.

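In the above DerivedField element, simply specify the mapMissingTo attribute on its top-level Apply element:

<DerivedField name="handleInvalid(category)" optype="categorical" dataType="string">
    <Apply function="if" mapMissingTo="NULL">
        <!-- same isIn/FieldRef/Constant children as above -->
    </Apply>
</DerivedField>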

vruusmann commented 7 years ago

Actually, I think that the RareMerger transform can be represented using the standard StringIndexer transform - no need to extend JPMML-SparkML in any way.

The idea is to manually truncate StringIndexerModel#getLabels() to the desired length (say, keep the first 30 elements of the array, which represent the "popular" categories), and specify StringIndexerModel#setHandleInvalid("keep") (which then comes to represent all other, "unpopular" categories).
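Conceptually, truncating the label array and enabling "keep" turns the indexer into top-N bucketing. The sketch below illustrates this with hypothetical names. In Spark itself, one would presumably construct a new StringIndexerModel from the truncated array rather than use this code.

Note that with this approach, seen-but-rare values and never-seen values end up in the same bucket.

```java
import java.util.Arrays;

// Hypothetical sketch of "truncated labels + handleInvalid=keep"; not Spark code.
class TruncatedIndexerSketch {

    // Keep the n most frequent labels (assumes the array is already sorted by descending frequency)
    static String[] truncate(String[] labels, int n) {
        return Arrays.copyOf(labels, Math.min(n, labels.length));
    }

    // Popular labels keep their position; everything else lands in the "keep" bucket
    static int index(String[] popular, String value) {
        for (int i = 0; i < popular.length; i++) {
            if (popular[i].equals(value)) {
                return i;
            }
        }
        return popular.length;
    }

    public static void main(String[] args) {
        String[] popular = truncate(new String[]{"6", "3", "5", "999999"}, 2);
        System.out.println(Arrays.toString(popular)); // [6, 3]
        System.out.println(index(popular, "5"));      // 2
    }
}
```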

fatihtekin commented 7 years ago

I have another issue: JPMML-Evaluator depends on com.google.guava:guava:20.0, while Spark depends on 11.0.2. Whichever of them I use, I get either

Exception in thread "main" java.lang.IllegalAccessError: tried to access method com.google.common.base.Stopwatch.<init>()V from class org.apache.hadoop.mapreduce.lib.input.FileInputFormat

or

Exception in thread "main" java.lang.NoClassDefFoundError: com/google/common/cache/CacheBuilderSpec
    at org.jpmml.evaluator.CacheUtil.<clinit>(CacheUtil.java:112)
    at org.jpmml.evaluator.ModelEvaluator.<clinit>(ModelEvaluator.java:671)
    at org.jpmml.evaluator.ModelEvaluatorFactory.newModelEvaluator(ModelEvaluatorFactory.java:103)
    at org.jpmml.evaluator.ModelEvaluatorFactory.newModelEvaluator(ModelEvaluatorFactory.java:66)
    at org.apache.spark.ml.pmml.PmmlTries$.main(PmmlTries.scala:60)
    at org.apache.spark.ml.pmml.PmmlTries.main(PmmlTries.scala)
Caused by: java.lang.ClassNotFoundException: com.google.common.cache.CacheBuilderSpec
    at java.net.URLClassLoader.findClass(URLClassLoader.java:381)
    at java.lang.ClassLoader.loadClass(ClassLoader.java:424)
    at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:335)
    at java.lang.ClassLoader.loadClass(ClassLoader.java:357)
    ... 6 more

fatihtekin commented 7 years ago

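I got it sorted by adding the following to the build.sbt file, in case someone else needs it as well:

// Relocate Guava's classes (package com.google.common) inside the assembly JAR,
// so that the Guava version required by JPMML-Evaluator no longer clashes with Spark's Guava 11.0.2
assemblyShadeRules in assembly := Seq(
  ShadeRule.rename("com.google.common.**" -> "shadeio.@1").inAll
)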

fatihtekin commented 7 years ago

When I use the RareMergerModelConverter below, I am getting an exception in StringIndexerModelConverter. Could you tell me what I am doing wrong?

Exception in thread "main" java.lang.IllegalArgumentException: categoryMerge
    at org.jpmml.converter.PMMLEncoder.toCategorical(PMMLEncoder.java:145)
    at org.jpmml.sparkml.feature.StringIndexerModelConverter.encodeFeatures(StringIndexerModelConverter.java:54)
    at org.jpmml.sparkml.FeatureConverter.registerFeatures(FeatureConverter.java:47)
    at org.jpmml.sparkml.ConverterUtil.toPMML(ConverterUtil.java:75)

Code

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.dmg.pmml.*;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DOMUtil;
import org.jpmml.converter.Feature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;
import org.apache.spark.ml.pmml.RareMergerModel;
import javax.xml.parsers.DocumentBuilder;

public class RareMergerModelConverter extends FeatureConverter<RareMergerModel> {

    public RareMergerModelConverter(RareMergerModel transformer){
        super(transformer);
    }

    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder){
        RareMergerModel transformer = getTransformer();
        Feature feature = encoder.getOnlyFeature(transformer.getInputCol());
        List<String> columns = Arrays.asList("inputValue", "outputValue");
        InlineTable inlineTable = new InlineTable();
        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();

        for(String popularLabel : transformer.popularLabels()){
            Row row = DOMUtil.createRow(documentBuilder, columns, Arrays.asList(popularLabel, popularLabel));
            inlineTable.addRows(row);
        }

        for(String rareLabel : transformer.rareLabels()){
            Row row = DOMUtil.createRow(documentBuilder, columns, Arrays.asList(rareLabel, transformer.rareLabel()));
            inlineTable.addRows(row);
        }

        MapValues mapValues = new MapValues()
                .addFieldColumnPairs(new FieldColumnPair(feature.getName(), columns.get(0)))
                .setOutputColumn(columns.get(1))
                .setInlineTable(inlineTable)
                .setDefaultValue("__unknown");

        mapValues.setMapMissingTo("__unknown");
        if(transformer.isNullLabelAdded()){
            mapValues.setMapMissingTo("NULL");
        }

        DerivedField derivedField = encoder.createDerivedField(formatName(transformer), OpType.CONTINUOUS, DataType.STRING, mapValues);
        return Collections.<Feature>singletonList(new ContinuousFeature(encoder, derivedField));
    }
}

vruusmann commented 7 years ago

When I use below RareMergerModelConverter, ...

Your RareMergerModelConverter class is close to perfect. Looks like you've successfully managed to figure out a great deal of JPMML-SparkML API design and architecture principles on your own.

One issue that I'm seeing with your code is that the operational type of the derived field should probably be OpType.CATEGORICAL (not OpType.CONTINUOUS), because string values do not have comparison operations (eg. "<", "<=", ">=", ">") defined for them. Following this thought, the generated feature should be an instance of CategoricalFeature (not ContinuousFeature).

... I am getting exception at StringIndexerModelConverter.

The trouble is that JPMML-SparkML assumes that StringIndexerModel will always be the first transformer in the pipeline (for a particular column). This assumption is hard-coded in the form that the corresponding column name must resolve to a DataField element. In your case, StringIndexerModel is the second transformer in the pipeline (following the RareMergerModel transformer), and the corresponding column name resolves to a DerivedField element.

Class StringIndexerModelConverter should contain a special instanceof check to handle this situation. Essentially, if feature instanceof CategoricalFeature evaluates to true, then there is no need to invoke the #toCategorical(...) logic anymore.

vruusmann commented 7 years ago

So, what's the solution now?

Your RareMerger use case is almost completely handled by StringIndexer#setHandleInvalid("keep"); the only addition is that you want to replace missing values with the __unknown constant.

I would still advise you to stop pursuing this RareMerger path, because it adds unnecessary complexity to your project/application. In PMML, you're supposed to use the MiningField@missingValueReplacement attribute for missing value replacement functionality:

<RegressionModel>
    <MiningSchema>
        <MiningField name="category" missingValueReplacement="__unknown"/>
    </MiningSchema>
</RegressionModel>

You can add/modify/remove PMML elements and attributes using the JPMML-Model library (e.g. using its very functional and powerful Visitors API). There is no need to change anything on the Apache Spark ML side.

Here's the pattern:

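PipelineModel pipelineModel = pipeline.fit(df);

org.dmg.pmml.PMML pmml = ConverterUtil.toPMML(df.schema(), pipelineModel);
pmml = performApplicationSpecificCustomizations(pmml); // THIS!
// JAXBUtil.marshalPMML writes to a javax.xml.transform.Result, hence the StreamResult wrapper
JAXBUtil.marshalPMML(pmml, new javax.xml.transform.stream.StreamResult(System.out));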

In this fictional performApplicationSpecificCustomizations utility method you can modify the live org.dmg.pmml.PMML class model object in any way you want.
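The thread recommends JPMML-Model's Visitors API for this step. As a swapped-in, library-independent alternative, here is a sketch that applies the same MiningField customization to the serialized PMML using the JDK's W3C DOM API.

It sets MiningField@missingValueReplacement for one field. The class and method names are hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

// Hypothetical DOM-based variant of performApplicationSpecificCustomizations (not the Visitors API).
class MissingValueReplacementSketch {

    static Document parse(String xml) {
        try {
            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
            factory.setNamespaceAware(true);
            return factory.newDocumentBuilder().parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    // Sets MiningField@missingValueReplacement for the named field; returns how many MiningFields changed
    static int apply(Document document, String fieldName, String replacement) {
        NodeList miningFields = document.getElementsByTagNameNS("*", "MiningField");
        int changed = 0;
        for (int i = 0; i < miningFields.getLength(); i++) {
            Element miningField = (Element) miningFields.item(i);
            if (fieldName.equals(miningField.getAttribute("name"))) {
                miningField.setAttribute("missingValueReplacement", replacement);
                changed++;
            }
        }
        return changed;
    }

    public static void main(String[] args) throws Exception {
        // Usage: java MissingValueReplacementSketch model.pmml > model-customized.pmml
        Document document = parse(new String(Files.readAllBytes(Paths.get(args[0])), StandardCharsets.UTF_8));
        apply(document, "category", "__unknown");
        TransformerFactory.newInstance().newTransformer().transform(new DOMSource(document), new StreamResult(System.out));
    }
}
```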

fatihtekin commented 7 years ago

You are absolutely right from the maintenance perspective. Unfortunately, I have to group the rare values and the unknown values separately, so just using StringIndexer with the truncated label list as you suggested won't be enough. Also, I would have to write another transformer anyway, so I prefer to write another converter for now. The other thing RareMerger does is decide whether a label is popular by checking its ratio among all non-null values (using a threshold variable).
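The popularity test described above can be sketched in plain Java. This is a hypothetical reconstruction: the actual RareMerger code is only attached to the thread, not shown. Two choices in it are assumptions:

- A label counts as popular when its ratio is >= the threshold, rather than >.
- Nulls are excluded from the denominator.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

// Hypothetical reconstruction of a RareMerger-style popularity split; not the author's code.
class RareSplitSketch {

    // A label is "popular" if its share of all non-null values is >= threshold (>= is an assumption)
    static Set<String> popularLabels(List<String> column, double threshold) {
        Map<String, Integer> counts = new TreeMap<>();
        int nonNull = 0;
        for (String value : column) {
            if (value == null) {
                continue; // nulls are excluded from the ratio and handled separately ("NULL")
            }
            counts.merge(value, 1, Integer::sum);
            nonNull++;
        }
        Set<String> popular = new TreeSet<>();
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if ((double) entry.getValue() / nonNull >= threshold) {
                popular.add(entry.getKey());
            }
        }
        return popular;
    }

    public static void main(String[] args) {
        List<String> column = Arrays.asList("a", "a", "a", "b", "b", "c", null);
        System.out.println(popularLabels(column, 0.3)); // [a, b]
    }
}
```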

fatihtekin commented 7 years ago

Would it be possible for you to make StringIndexerModelConverter support the case where another transformer/estimator precedes StringIndexer? Then I wouldn't need to maintain my own copy of StringIndexerModelConverter. It may also help others in the future who need to run other estimators before StringIndexer.

Currently, this is what I am doing in my copy of StringIndexerModelConverter:

DataField dataField;
if(feature instanceof CategoricalFeature){
    dataField = new DataField(new FieldName(transformer.getInputCol()), OpType.CATEGORICAL, DataType.STRING);
} else {
    dataField = encoder.toCategorical(feature.getName(), labels);
}

This is the latest code of RareMergerModelConverter:


import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import javax.xml.parsers.DocumentBuilder;

import org.apache.spark.ml.pmml.RareMergerModel;
import org.dmg.pmml.*;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.DOMUtil;
import org.jpmml.converter.Feature;
import org.jpmml.sparkml.FeatureConverter;
import org.jpmml.sparkml.SparkMLEncoder;

public class RareMergerModelConverter extends FeatureConverter<RareMergerModel> {

    public RareMergerModelConverter(RareMergerModel transformer){
        super(transformer);
    }

    @Override
    public List<Feature> encodeFeatures(SparkMLEncoder encoder){
        RareMergerModel transformer = getTransformer();
        Feature feature = encoder.getOnlyFeature(transformer.getInputCol());

        List<String> columns = Arrays.asList("inputValue", "outputValue");
        InlineTable inlineTable = new InlineTable();
        DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder();

        // Popular labels map to themselves
        for(String popularLabel : transformer.popularLabels()){
            Row row = DOMUtil.createRow(documentBuilder, columns, Arrays.asList(popularLabel, popularLabel));
            inlineTable.addRows(row);
        }

        // Rare labels map to the single rare label (e.g. "RARE")
        for(String rareLabel : transformer.rareLabels()){
            Row row = DOMUtil.createRow(documentBuilder, columns, Arrays.asList(rareLabel, transformer.getRareLabel()));
            inlineTable.addRows(row);
        }

        // Unseen values and (by default) missing values map to "__unknown"
        MapValues mapValues = new MapValues()
                .addFieldColumnPairs(new FieldColumnPair(feature.getName(), columns.get(0)))
                .setOutputColumn(columns.get(1))
                .setInlineTable(inlineTable)
                .setDefaultValue("__unknown")
                .setMapMissingTo("__unknown");

        // The categories of the derived field are the MapValues *output* values,
        // so the individual rare labels are replaced by the single rare label
        List<String> dataFieldLabels = new ArrayList<>(Arrays.asList(transformer.popularLabels()));
        dataFieldLabels.add(transformer.getRareLabel());
        dataFieldLabels.add("__unknown");

        // If training saw null values, missing values get their own category instead
        if(transformer.getNullToString() && transformer.isNullLabelAdded()){
            mapValues.setMapMissingTo(transformer.getNullLabel());
            dataFieldLabels.add(transformer.getNullLabel());
        }

        DerivedField derivedField = encoder.createDerivedField(formatName(transformer), OpType.CATEGORICAL, DataType.STRING, mapValues);
        return Collections.<Feature>singletonList(new CategoricalFeature(encoder, derivedField, dataFieldLabels));
    }
}