xydrolase / shap4j

JVM bindings for the SHAP library
MIT License

explainer.shapValues(values, false) crashes after being called 1000 times #9

Open · sudheerprem opened this issue 2 years ago

sudheerprem commented 2 years ago

Calling explainer.shapValues(values, false) in a for loop crashes the JVM after 1000 calls.

sudheerprem commented 2 years ago

Attached is the crash log:

A fatal error has been detected by the Java Runtime Environment:

  SIGSEGV (0xb) at pc=0x000000010a8860ec, pid=29549, tid=0x0000000000000d03

 JRE version: Java(TM) SE Runtime Environment (8.0_151-b12) (build 1.8.0_151-b12)
 Java VM: Java HotSpot(TM) 64-Bit Server VM (25.151-b12 mixed mode bsd-amd64 compressed oops)
 Problematic frame:
 C  [libjniTreeShap.dylib+0x40ec]  dense_tree_path_dependent(TreeEnsemble const&, ExplanationDataset const&, double*, double (*)(double, double))+0x36c

 Failed to write core dump. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again

 An error report file with more information is saved as:
/shap4jtest/hs_err_pid29549.log

 If you would like to submit a bug report, please visit:
   http://bugreport.java.com/bugreport/crash.jsp
 The crash happened outside the Java Virtual Machine in native code.
 See problematic frame for where to report the bug.

xydrolase commented 2 years ago

@sudheerprem Hi, thanks for submitting the issue. Is it possible for you to provide the shap4j model and the values for reproducing the error?

liuchenustb commented 1 year ago

I have the same problem on Linux.

liuchenustb commented 1 year ago

[Screenshot 2022-09-02 at 2:58:55 PM]
xydrolase commented 1 year ago

@liuchenustb Thanks for reporting the issue. If possible, can you provide the model file and the vector for reproducing this issue?

adrianbouza commented 4 months ago

Hi, I've started using this repo and encountered the same problem.

I think I've found a possible solution.

When you call the explainer in a loop, you reuse the same TreeEnsemble object. This object is created by the fromBytes method, which loads the shap4j explainer. In this method there is a pointer that I think is not being handled correctly:

        // allocate a native memory block, and copy the java array to the native memory block
        BytePointer rawDataPtr = new BytePointer(rawData.length);
        rawDataPtr.put(rawData, 0, rawData.length);

That pointer is never closed, and it isn't managed by a PointerScope.

I've modified the TreeEnsemble and TreeExplainer classes to fix this. In TreeEnsemble, I create a PointerScope that holds rawDataPtr, and in TreeExplainer I added a close method that closes the TreeEnsemble and, with it, the TreeEnsemble's PointerScope.
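The ownership chain described above can be sketched in plain Java with no JavaCPP dependency. This is a minimal illustration under stated assumptions: NativeBlock, Scope, Ensemble, and Explainer are hypothetical stand-ins for BytePointer, PointerScope, TreeEnsemble, and TreeExplainer, and Scope only roughly mirrors what PointerScope.attach() and close() do. The point is that the object handed to native code keeps a strong reference to the memory it reads from, and that closing the explainer cascades down to that memory.

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class OwnershipSketch {
    /** Hypothetical stand-in for a native allocation such as a BytePointer. */
    static final class NativeBlock implements AutoCloseable {
        private boolean released = false;
        boolean isReleased() { return released; }
        @Override public void close() { released = true; }
    }

    /** Hypothetical stand-in for PointerScope: holds attached blocks and releases them on close. */
    static final class Scope implements AutoCloseable {
        private final Deque<NativeBlock> attached = new ArrayDeque<>();
        NativeBlock attach(NativeBlock block) {
            attached.push(block); // strong reference: the block stays reachable while the scope is
            return block;
        }
        @Override public void close() {
            // release in reverse order of attachment
            while (!attached.isEmpty()) attached.pop().close();
        }
    }

    /** Stand-in for TreeEnsemble: owns the scope, so the raw block cannot be collected early. */
    static final class Ensemble implements AutoCloseable {
        private final Scope scope;
        Ensemble(Scope scope) { this.scope = scope; }
        @Override public void close() { scope.close(); }
    }

    /** Stand-in for TreeExplainer: closing it closes the ensemble and, through it, the scope. */
    static final class Explainer implements AutoCloseable {
        private final Ensemble ensemble;
        Explainer(Ensemble ensemble) { this.ensemble = ensemble; }
        @Override public void close() { ensemble.close(); }
    }

    public static void main(String[] args) {
        Scope scope = new Scope();
        NativeBlock raw = scope.attach(new NativeBlock());
        Explainer explainer = new Explainer(new Ensemble(scope));
        System.out.println("released before close: " + raw.isReleased());
        explainer.close();
        System.out.println("released after close: " + raw.isReleased());
    }
}
```

Running main prints `released before close: false` and then `released after close: true`: the block lives exactly as long as the explainer that (indirectly) owns it.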

TreeEnsemble:

package shap4j.shap;

import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.ByRef;
import org.bytedeco.javacpp.annotation.Const;
import org.bytedeco.javacpp.annotation.Platform;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

@Platform(include="shap4j/shap/tree_shap.h")
public class TreeEnsemble extends Pointer {
    static {
        Loader.load();
    }

    private TreeEnsemble() {
        allocate();
    }
    private native void allocate();
    // holds the native blocks this ensemble points into, keeping them reachable until close()
    private PointerScope scope;

    private TreeEnsemble(
            IntPointer children_left, IntPointer children_right, IntPointer children_default, IntPointer features,
            DoublePointer thresholds, DoublePointer values, DoublePointer node_sample_weights,
            int max_depth, int tree_limit, DoublePointer base_offset, int max_nodes, int num_outputs, PointerScope scope
    ) {
        this.scope = scope;
        allocate(children_left, children_right, children_default, features, thresholds, values, node_sample_weights,
                max_depth, tree_limit, base_offset, max_nodes, num_outputs);
    }
    private native void allocate(
            IntPointer children_left, IntPointer children_right, IntPointer children_default, IntPointer features,
            DoublePointer thresholds, DoublePointer values, DoublePointer node_sample_weights,
            int max_depth, int tree_limit, DoublePointer base_offset, int max_nodes, int num_outputs
    );

    public native void get_tree(@ByRef TreeEnsemble tree, @Const int i);

    // FIXME: figure out how to map this to the C++ allocate() function
    // public native void allocate(int tree_limit_in, int max_nodes_in, int num_outputs_in);

    // methods mapping to the struct fields of TreeEnsemble; all setters are private to prevent those fields from being
    // updated in JVM;
    public native int num_outputs(); private native void num_outputs(int setter);
    public native int tree_limit(); private native void tree_limit(int setter);
    public native int max_nodes(); private native void max_nodes(int setter);
    native IntPointer children_left(); private native void children_left(IntPointer setter);
    native IntPointer children_right(); private native void children_right(IntPointer setter);

    public int getChildrenLeft(int treeIndex, int nodeIndex) {
        return children_left().get(treeIndex * max_nodes() + nodeIndex);
    }

    public int getChildrenRight(int treeIndex, int nodeIndex) {
        return children_right().get(treeIndex * max_nodes() + nodeIndex);
    }

    public native void free();

    private static IntPointer getIntPointer(BytePointer base, int position, int numElements) {
        IntPointer ptr = new IntPointer(base);
        return ptr.position(position).limit(position + numElements);
    }

    private static DoublePointer getDoublePointer(BytePointer base, int position, int numElements) {
        DoublePointer ptr = new DoublePointer(base);
        return ptr.position(position).limit(position + numElements);
    }

    public static TreeEnsemble fromBytes(byte[] rawData) {
        ByteBuffer buffer = ByteBuffer.wrap(rawData).order(ByteOrder.nativeOrder());

        byte[] magicBytes = new byte[4];
        buffer.get(magicBytes, 0, 4);
        int version = buffer.getInt();

        // note: these asserts are only evaluated when the JVM is started with -ea
        assert new String(magicBytes).equals("SHAP");
        assert version == 1;

        int numTrees = buffer.getInt();
        int maxDepth = buffer.getInt();
        int maxNodes = buffer.getInt();
        int numOutputs = buffer.getInt();
        int offsetIntArrays = buffer.getInt();
        int offsetDoubleArrays = buffer.getInt();
        double baseOffset = buffer.getDouble();
        DoublePointer ptrBaseOffset = new DoublePointer(1);
        ptrBaseOffset.put(baseOffset);

        int numElements = numTrees * maxNodes;
        // this scope is handed to the returned TreeEnsemble and closed in TreeEnsemble.close()
        PointerScope mScope = new PointerScope();
        // base_offset is also passed to the native TreeEnsemble, so keep it alive in the same scope
        mScope.attach(ptrBaseOffset);

        // allocate a native memory block, and copy the java array to the native memory block
        BytePointer rawDataPtr = new BytePointer(rawData.length);
        rawDataPtr.put(rawData, 0, rawData.length);
        mScope.attach(rawDataPtr);

        // create pointers pointing to different sections of the memory block (allocated through rawDataPtr)
        IntPointer childrenLeft = getIntPointer(rawDataPtr, offsetIntArrays >> 2, numElements);
        IntPointer childrenRight = getIntPointer(rawDataPtr, (int) childrenLeft.limit(), numElements);
        IntPointer childrenDefault = getIntPointer(rawDataPtr, (int) childrenRight.limit(), numElements);
        IntPointer features = getIntPointer(rawDataPtr, (int) childrenDefault.limit(), numElements);

        DoublePointer thresholds = getDoublePointer(rawDataPtr, offsetDoubleArrays >> 3, numElements);
        DoublePointer values = getDoublePointer(rawDataPtr, (int) thresholds.limit(), numElements * numOutputs);
        DoublePointer nodeSampleWeight = getDoublePointer(rawDataPtr, (int) values.limit(), numElements);

        return new TreeEnsemble(
                childrenLeft, childrenRight, childrenDefault, features, thresholds, values, nodeSampleWeight,
                maxDepth, numTrees, ptrBaseOffset, maxNodes, numOutputs, mScope
        );
    }

    @Override
    public void close() {
        // release all pointers attached to this ensemble's scope, then the ensemble itself
        if (scope != null) scope.close();
        super.close();
    }
}
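The binary layout that fromBytes reads can be sketched as a standalone header parser. The field order, the 4-byte "SHAP" magic, the version check, the index arithmetic (offsetIntArrays >> 2, offsetDoubleArrays >> 3, consecutive sections of numTrees * maxNodes elements) and the treeIndex * maxNodes + nodeIndex lookup are taken from the code above. The ModelHeader class, the sampleHeader helper, and the sample numbers are hypothetical, for illustration only. The checks are explicit because Java assert statements are skipped unless the JVM runs with -ea.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

public class ShapHeaderSketch {
    /** Hypothetical holder for the header fields that TreeEnsemble.fromBytes reads. */
    static final class ModelHeader {
        final int numTrees, maxDepth, maxNodes, numOutputs, offsetIntArrays, offsetDoubleArrays;
        final double baseOffset;

        ModelHeader(int numTrees, int maxDepth, int maxNodes, int numOutputs,
                    int offsetIntArrays, int offsetDoubleArrays, double baseOffset) {
            this.numTrees = numTrees;
            this.maxDepth = maxDepth;
            this.maxNodes = maxNodes;
            this.numOutputs = numOutputs;
            this.offsetIntArrays = offsetIntArrays;
            this.offsetDoubleArrays = offsetDoubleArrays;
            this.baseOffset = baseOffset;
        }

        int numElements() { return numTrees * maxNodes; }

        // int sections: a byte offset >> 2 is an index in 4-byte elements
        int childrenLeftStart() { return offsetIntArrays >> 2; }
        int childrenRightStart() { return childrenLeftStart() + numElements(); }
        int childrenDefaultStart() { return childrenRightStart() + numElements(); }
        int featuresStart() { return childrenDefaultStart() + numElements(); }

        // double sections: a byte offset >> 3 is an index in 8-byte elements
        int thresholdsStart() { return offsetDoubleArrays >> 3; }
        int valuesStart() { return thresholdsStart() + numElements(); }
        int nodeSampleWeightStart() { return valuesStart() + numElements() * numOutputs; }

        // same flat indexing as getChildrenLeft / getChildrenRight
        int flatIndex(int treeIndex, int nodeIndex) { return treeIndex * maxNodes + nodeIndex; }
    }

    static ModelHeader parse(byte[] rawData) {
        ByteBuffer buffer = ByteBuffer.wrap(rawData).order(ByteOrder.nativeOrder());
        byte[] magic = new byte[4];
        buffer.get(magic, 0, 4);
        int version = buffer.getInt();
        // explicit checks instead of assert, which is a no-op without -ea
        if (!"SHAP".equals(new String(magic, StandardCharsets.US_ASCII))) {
            throw new IllegalArgumentException("missing SHAP magic bytes");
        }
        if (version != 1) {
            throw new IllegalArgumentException("unsupported version: " + version);
        }
        int numTrees = buffer.getInt();
        int maxDepth = buffer.getInt();
        int maxNodes = buffer.getInt();
        int numOutputs = buffer.getInt();
        int offsetIntArrays = buffer.getInt();
        int offsetDoubleArrays = buffer.getInt();
        double baseOffset = buffer.getDouble();
        return new ModelHeader(numTrees, maxDepth, maxNodes, numOutputs,
                offsetIntArrays, offsetDoubleArrays, baseOffset);
    }

    // hypothetical helper: writes a 40-byte header (4 + 4 + 6*4 + 8) in the same field order
    static byte[] sampleHeader(int numTrees, int maxDepth, int maxNodes, int numOutputs,
                               int offsetIntArrays, int offsetDoubleArrays, double baseOffset) {
        ByteBuffer b = ByteBuffer.allocate(40).order(ByteOrder.nativeOrder());
        b.put("SHAP".getBytes(StandardCharsets.US_ASCII)).putInt(1)
                .putInt(numTrees).putInt(maxDepth).putInt(maxNodes).putInt(numOutputs)
                .putInt(offsetIntArrays).putInt(offsetDoubleArrays).putDouble(baseOffset);
        return b.array();
    }

    public static void main(String[] args) {
        // sample: 2 trees x 7 nodes; ints start after the 40-byte header,
        // doubles after 4 int sections of 14 elements (40 + 4 * 14 * 4 = 264 bytes)
        ModelHeader h = parse(sampleHeader(2, 3, 7, 1, 40, 264, 0.5));
        System.out.println("children_left starts at int index " + h.childrenLeftStart());
        System.out.println("thresholds start at double index " + h.thresholdsStart());
        System.out.println("tree 1, node 3 -> flat index " + h.flatIndex(1, 3));
    }
}
```

With the sample values, children_left starts at int index 10, thresholds at double index 33, and node 3 of tree 1 maps to flat index 10.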

TreeExplainer:

package shap4j;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.PointerScope;
import shap4j.shap.ExplanationDataset;
import shap4j.shap.TreeEnsemble;
import shap4j.shap.TreeShap;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;

/**
 * A SHAP explainer using Tree SHAP algorithms to explain the output of tree ensemble models.
 *
 * @see <a href="https://github.com/slundberg/shap/blob/master/shap/explainers/tree.py">Python interface for TreeExplainer</a>
 */
public class TreeExplainer {
    private static final int TREE_PATH_DEPENDENT_FEATURE = 1;
    private static final int IDENTITY_TRANSFORM = 0;

    private TreeEnsemble treeEnsemble;

    private TreeExplainer(TreeEnsemble ensemble) {
        this.treeEnsemble = ensemble;
    }

   ...

    public void close() {
        // closing the ensemble also closes its PointerScope, releasing the attached native memory
        if (this.treeEnsemble != null) this.treeEnsemble.close();
    }
}

With these changes, I can reuse the explainer in a loop without memory issues, and when the loop finishes I simply close the explainer.
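A minimal sketch of that calling pattern: one explainer reused across all iterations, closed exactly once in a finally block so the native memory is released even if an iteration throws. FakeExplainer is a hypothetical stand-in for the modified TreeExplainer. The meaning of the boolean argument to shapValues is not stated in this thread, so it is simply passed through as false. The posted TreeExplainer does not implement AutoCloseable, which is why try/finally is used rather than try-with-resources.

```java
public class LoopUsageSketch {
    /** Hypothetical stand-in for the modified TreeExplainer. */
    static final class FakeExplainer {
        int calls = 0;
        boolean closed = false;

        // mirrors the shape of explainer.shapValues(values, false); the flag is passed through unused
        double[] shapValues(double[] x, boolean flag) {
            if (closed) throw new IllegalStateException("explainer already closed");
            calls++;
            return new double[x.length]; // placeholder attributions
        }

        void close() { closed = true; }
    }

    static FakeExplainer runLoop(int iterations) {
        FakeExplainer explainer = new FakeExplainer();
        try {
            double[] x = {1.0, 2.0, 3.0};
            for (int i = 0; i < iterations; i++) {
                explainer.shapValues(x, false); // reuse the same explainer every iteration
            }
        } finally {
            // close once, after the loop, even if an iteration throws
            explainer.close();
        }
        return explainer;
    }

    public static void main(String[] args) {
        FakeExplainer e = runLoop(1000);
        System.out.println(e.calls + " calls, closed=" + e.closed);
    }
}
```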

I'm not sure if this is the correct solution, but it has fixed the problem for me. I hope it helps someone.