Open sudheerprem opened 2 years ago
Attached the Crash log:
A fatal error has been detected by the Java Runtime Environment:
SIGSEGV (0xb) at pc=0x000000010a8860ec, pid=29549, tid=0x0000000000000d03
JRE version: Java(TM) SE Runtime Environment (8.0_151-b12) (build 1.8.0_151-b12)
Java VM: Java HotSpot(TM) 64-Bit Server VM (25.151-b12 mixed mode bsd-amd64 compressed oops)
Problematic frame:
C [libjniTreeShap.dylib+0x40ec] dense_tree_path_dependent(TreeEnsemble const&, ExplanationDataset const&, double*, double (*)(double, double))+0x36c
Failed to write core dump. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
An error report file with more information is saved as:
/shap4jtest/hs_err_pid29549.log
If you would like to submit a bug report, please visit:
http://bugreport.java.com/bugreport/crash.jsp
The crash happened outside the Java Virtual Machine in native code.
See problematic frame for where to report the bug.
@sudheerprem Hi, thanks for submitting the issue. Is it possible for you to provide the shap4j model and the values for reproducing the error?
I had the same problem on a Linux system.
@liuchenustb Thanks for reporting the issue. If possible, can you provide the model file and the vector for reproducing this issue?
Hi, I've started using this repo and encountered the same problem.
I think I've found a possible solution.
When you call the explainer in a loop, you are reusing the TreeEnsemble object. This object is created by the fromBytes method, which loads the shap4j explainer. In that method there is a pointer that I think is not being handled correctly:
// allocate a native memory block, and copy the java array to the native memory block
BytePointer rawDataPtr = new BytePointer(rawData.length);
rawDataPtr.put(rawData, 0, rawData.length);
That pointer is never closed, and it isn't managed by a PointerScope.
I've modified the TreeEnsemble and TreeExplainer classes to fix this. In TreeEnsemble I've created a PointerScope that manages rawDataPtr, and in TreeExplainer I've added a close method that closes the TreeEnsemble's PointerScope.
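The ownership idea behind this fix can be sketched without any native code. The block below is a simplified model, not JavaCPP's actual PointerScope: ScopeSketch and FakeNativeBuffer are hypothetical stand-ins, and the only thing they are meant to show is that whatever is attached to a scope gets released when the scope is closed.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** Simplified stand-in for a scope that owns allocations (NOT JavaCPP's PointerScope). */
final class ScopeSketch implements AutoCloseable {
    private final Deque<AutoCloseable> owned = new ArrayDeque<>();

    /** Registers a resource so that closing the scope releases it as well. */
    <T extends AutoCloseable> T attach(T resource) {
        owned.push(resource);
        return resource;
    }

    /** Releases every attached resource, most recently attached first. */
    @Override
    public void close() {
        while (!owned.isEmpty()) {
            try {
                owned.pop().close();
            } catch (Exception e) {
                throw new IllegalStateException("failed to release resource", e);
            }
        }
    }
}

/** Hypothetical stand-in for a native block such as the one behind rawDataPtr. */
final class FakeNativeBuffer implements AutoCloseable {
    static int liveBuffers = 0;   // buffers allocated but not yet released
    final int size;               // pretend size in bytes
    private boolean released = false;

    FakeNativeBuffer(int size) {
        this.size = size;
        liveBuffers++;
    }

    /** Releasing twice is harmless; the counter only drops once. */
    @Override
    public void close() {
        if (!released) {
            released = true;
            liveBuffers--;
        }
    }
}
```

In this model, an object that keeps a reference to its scope (as the modified TreeEnsemble does) only needs to close that scope in its own close method to release everything it allocated.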
TreeEnsemble:
package shap4j.shap;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.ByRef;
import org.bytedeco.javacpp.annotation.Const;
import org.bytedeco.javacpp.annotation.Platform;
import java.io.File;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
@Platform(include="shap4j/shap/tree_shap.h")
public class TreeEnsemble extends Pointer {
    static {
        Loader.load();
    }
    private TreeEnsemble() {
        allocate();
    }
    private native void allocate();
    private PointerScope scope;
    private TreeEnsemble(
            IntPointer children_left, IntPointer children_right, IntPointer children_default, IntPointer features,
            DoublePointer thresholds, DoublePointer values, DoublePointer node_sample_weights,
            int max_depth, int tree_limit, DoublePointer base_offset, int max_nodes, int num_outputs, PointerScope scope
    ) {
        this.scope = scope;
        allocate(children_left, children_right, children_default, features, thresholds, values, node_sample_weights,
                max_depth, tree_limit, base_offset, max_nodes, num_outputs);
    }
    private native void allocate(
            IntPointer children_left, IntPointer children_right, IntPointer children_default, IntPointer features,
            DoublePointer thresholds, DoublePointer values, DoublePointer node_sample_weights,
            int max_depth, int tree_limit, DoublePointer base_offset, int max_nodes, int num_outputs
    );
    public native void get_tree(@ByRef TreeEnsemble tree, @Const int i);
    // FIXME: figure out how to map this to the C++ allocate() function
    // public native void allocate(int tree_limit_in, int max_nodes_in, int num_outputs_in);
    // methods mapping to the struct fields of TreeEnsemble; all setters are private to prevent those fields from being
    // updated in JVM;
    public native int num_outputs(); private native void num_outputs(int setter);
    public native int tree_limit(); private native void tree_limit(int setter);
    public native int max_nodes(); private native void max_nodes(int setter);
    native IntPointer children_left(); private native void children_left(IntPointer setter);
    native IntPointer children_right(); private native void children_right(IntPointer setter);
    public int getChildrenLeft(int treeIndex, int nodeIndex) {
        return children_left().get(treeIndex * max_nodes() + nodeIndex);
    }
    public int getChildrenRight(int treeIndex, int nodeIndex) {
        return children_right().get(treeIndex * max_nodes() + nodeIndex);
    }
    public native void free();
    private static IntPointer getIntPointer(BytePointer base, int position, int numElements) {
        IntPointer ptr = new IntPointer(base);
        return ptr.position(position).limit(position + numElements);
    }
    private static DoublePointer getDoublePointer(BytePointer base, int position, int numElements) {
        DoublePointer ptr = new DoublePointer(base);
        return ptr.position(position).limit(position + numElements);
    }
    public static TreeEnsemble fromBytes(byte[] rawData) {
        ByteBuffer buffer = ByteBuffer.wrap(rawData).order(ByteOrder.nativeOrder());
        byte[] magicBytes = new byte[4];
        buffer.get(magicBytes, 0, 4);
        int version = buffer.getInt();
        assert new String(magicBytes).equals("SHAP");
        assert version == 1;
        int numTrees = buffer.getInt();
        int maxDepth = buffer.getInt();
        int maxNodes = buffer.getInt();
        int numOutputs = buffer.getInt();
        int offsetIntArrays = buffer.getInt();
        int offsetDoubleArrays = buffer.getInt();
        double baseOffset = buffer.getDouble();
        DoublePointer ptrBaseOffset = new DoublePointer(1);
        ptrBaseOffset.put(baseOffset);
        int numElements = numTrees * maxNodes;
        PointerScope mScope = new PointerScope();
        // allocate a native memory block, and copy the java array to the native memory block
        BytePointer rawDataPtr = new BytePointer(rawData.length);
        rawDataPtr.put(rawData, 0, rawData.length);
        mScope.attach(rawDataPtr);
        // create pointers pointing to different sections of the memory block (allocated through rawDataPtr)
        IntPointer childrenLeft = getIntPointer(rawDataPtr, offsetIntArrays >> 2, numElements);
        IntPointer childrenRight = getIntPointer(rawDataPtr, (int) childrenLeft.limit(), numElements);
        IntPointer childrenDefault = getIntPointer(rawDataPtr, (int) childrenRight.limit(), numElements);
        IntPointer features = getIntPointer(rawDataPtr, (int) childrenDefault.limit(), numElements);
        DoublePointer thresholds = getDoublePointer(rawDataPtr, offsetDoubleArrays >> 3, numElements);
        DoublePointer values = getDoublePointer(rawDataPtr, (int) thresholds.limit(), numElements * numOutputs);
        DoublePointer nodeSampleWeight = getDoublePointer(rawDataPtr, (int) values.limit(), numElements);
        return new TreeEnsemble(
                childrenLeft, childrenRight, childrenDefault, features, thresholds, values, nodeSampleWeight,
                maxDepth, numTrees, ptrBaseOffset, maxNodes, numOutputs, mScope
        );
    }
    @Override
    public void close() {
        // release all pointers attached to the current scope;
        if (scope != null) scope.close();
        super.close();
    }
}
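A side note on fromBytes: the two assert statements only run when the JVM is started with assertions enabled (-ea), so by default a file with a wrong magic value or version goes straight on to the native allocation. Nothing in this thread shows that a malformed model caused the crash, so the following is only a sketch of explicit validation. HeaderSketch is a hypothetical helper that is not part of shap4j; the field order and the section sizes are taken from the fromBytes code above.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

/** Hypothetical helper: parses and validates the header that fromBytes reads. */
final class HeaderSketch {
    // 4 magic bytes + 7 ints (version .. offsetDoubleArrays) + 1 double (baseOffset)
    static final int HEADER_BYTES = 4 + 7 * 4 + 8;

    final int numTrees, maxDepth, maxNodes, numOutputs, offsetIntArrays, offsetDoubleArrays;
    final double baseOffset;

    private HeaderSketch(int numTrees, int maxDepth, int maxNodes, int numOutputs,
                         int offsetIntArrays, int offsetDoubleArrays, double baseOffset) {
        this.numTrees = numTrees;
        this.maxDepth = maxDepth;
        this.maxNodes = maxNodes;
        this.numOutputs = numOutputs;
        this.offsetIntArrays = offsetIntArrays;
        this.offsetDoubleArrays = offsetDoubleArrays;
        this.baseOffset = baseOffset;
    }

    static HeaderSketch parse(byte[] rawData) {
        if (rawData.length < HEADER_BYTES) {
            throw new IllegalArgumentException("truncated header: " + rawData.length + " bytes");
        }
        ByteBuffer buffer = ByteBuffer.wrap(rawData).order(ByteOrder.nativeOrder());
        byte[] magic = new byte[4];
        buffer.get(magic);
        if (!"SHAP".equals(new String(magic, StandardCharsets.US_ASCII))) {
            throw new IllegalArgumentException("bad magic bytes");
        }
        int version = buffer.getInt();
        if (version != 1) {
            throw new IllegalArgumentException("unsupported version: " + version);
        }
        int numTrees = buffer.getInt();
        int maxDepth = buffer.getInt();
        int maxNodes = buffer.getInt();
        int numOutputs = buffer.getInt();
        int offsetIntArrays = buffer.getInt();
        int offsetDoubleArrays = buffer.getInt();
        double baseOffset = buffer.getDouble();

        if (numTrees <= 0 || maxNodes <= 0 || numOutputs <= 0) {
            throw new IllegalArgumentException("non-positive tree dimensions");
        }
        // fromBytes shifts the offsets by 2 and 3 bits, which assumes 4- and 8-byte alignment.
        if (offsetIntArrays < 0 || offsetIntArrays % 4 != 0
                || offsetDoubleArrays < 0 || offsetDoubleArrays % 8 != 0) {
            throw new IllegalArgumentException("negative or misaligned section offset");
        }
        long numElements = (long) numTrees * maxNodes;
        // Four int arrays: children_left, children_right, children_default, features.
        long intEnd = offsetIntArrays + 4L * 4 * numElements;
        // Double arrays: thresholds, values (numElements * numOutputs), node_sample_weights.
        long doubleEnd = offsetDoubleArrays + 8L * (2 * numElements + numElements * numOutputs);
        if (intEnd > rawData.length || doubleEnd > rawData.length) {
            throw new IllegalArgumentException("array sections extend past the end of the data");
        }
        return new HeaderSketch(numTrees, maxDepth, maxNodes, numOutputs,
                offsetIntArrays, offsetDoubleArrays, baseOffset);
    }
}
```

Throwing before any native memory is allocated means a bad file surfaces as an ordinary Java exception rather than as an out-of-bounds read in the native library.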
TreeExplainer:
package shap4j;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.PointerScope;
import shap4j.shap.ExplanationDataset;
import shap4j.shap.TreeEnsemble;
import shap4j.shap.TreeShap;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
/**
 * A SHAP explainer using Tree SHAP algorithms to explain the output of tree ensemble models.
 *
 * @see <a href="https://github.com/slundberg/shap/blob/master/shap/explainers/tree.py">Python interface for TreeExplainer</a>
 */
public class TreeExplainer {
    private static final int TREE_PATH_DEPENDENT_FEATURE = 1;
    private static final int IDENTITY_TRANSFORM = 0;
    private TreeEnsemble treeEnsemble;
    private TreeExplainer(TreeEnsemble ensemble) {
        this.treeEnsemble = ensemble;
    }
    ...
    public void close() {
        if (this.treeEnsemble != null) this.treeEnsemble.close();
    }
}
With these changes I can reuse the explainer in a loop without memory issues, and when the loop finishes I just close the explainer.
I am not sure if this is the correct solution, but it has fixed the problem for me. Hope it helps someone.
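The calling pattern described here (reuse one explainer for the whole loop, then close it exactly once) might look like the sketch below. ExplainerSketch is a hypothetical stand-in for the patched TreeExplainer: only the shapValues(values, false) call shape and the close method come from this thread, and the returned array is a placeholder rather than real SHAP values.

```java
/** Hypothetical stand-in for the patched TreeExplainer; only the close() contract is modeled. */
final class ExplainerSketch {
    private boolean closed = false;

    /** Placeholder for the real call; refuses to run once the explainer has been closed. */
    double[] shapValues(double[] values, boolean flag) {
        if (closed) {
            throw new IllegalStateException("explainer already closed");
        }
        return new double[values.length]; // placeholder result, not real SHAP values
    }

    void close() {
        closed = true;
    }

    boolean isClosed() {
        return closed;
    }
}

final class ExplainerLoopSketch {
    /** Explains every row with the same explainer and closes it even if a call throws. */
    static int explainAll(ExplainerSketch explainer, double[][] rows) {
        int explained = 0;
        try {
            for (double[] row : rows) {
                explainer.shapValues(row, false);
                explained++;
            }
        } finally {
            explainer.close(); // release native memory exactly once, after the loop
        }
        return explained;
    }
}
```

The try/finally form is used because the close method shown above is a plain method; if TreeExplainer also implemented AutoCloseable, try-with-resources would do the same job.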
Calling explainer.shapValues(values, false) crashes after it has been called 1000 times in a for loop.