tensorflow / java

Java bindings for TensorFlow
Apache License 2.0
832 stars 202 forks source link

Custom gradient registration is broken in Windows #486

Open karllessard opened 1 year ago

karllessard commented 1 year ago

Since we upgraded to TF 2.10.1, the JVM crashes on Windows when attempting to register custom gradients. This has been observed when running the unit tests. For now, custom gradient registration on Windows has been disabled until this issue is fixed. It seems to be related to a wrong pointer returned by TF_OperationName.

Here are some traces:

---------------  T H R E A D  ---------------

Current thread (0x000002382b662000):  JavaThread "main" [_thread_in_native, id=5220, stack(0x00000045fcd00000,0x00000045fce00000)]

Stack: [0x00000045fcd00000,0x00000045fce00000],  sp=0x00000045fcdfac90,  free space=1003k
Native frames: (J=compiled Java code, j=interpreted, Vv=VM code, C=native code)
C  [jnitensorflow.dll+0x4c030]

Java frames: (J=compiled Java code, j=interpreted, Vv=VM code)
j  org.tensorflow.internal.c_api.NameMap.erase(Lorg/bytedeco/javacpp/BytePointer;)J+0
j  org.tensorflow.GraphOperationBuilder.finishDangerousGradient(Lorg/tensorflow/internal/c_api/TF_Graph;Lorg/tensorflow/internal/c_api/TF_OperationDescription;)Lorg/tensorflow/internal/c_api/TF_Operation;+36
j  org.tensorflow.GraphOperationBuilder.build()Lorg/tensorflow/GraphOperation;+34
j  org.tensorflow.GraphOperationBuilder.build()Lorg/tensorflow/Operation;+1
j  org.tensorflow.op.core.Constant.create(Lorg/tensorflow/op/Scope;Lorg/tensorflow/types/family/TType;)Lorg/tensorflow/op/core/Constant;+37
j  org.tensorflow.op.core.Constant.scalarOf(Lorg/tensorflow/op/Scope;F)Lorg/tensorflow/op/core/Constant;+7
j  org.tensorflow.op.Ops.constant(F)Lorg/tensorflow/op/core/Constant;+5
j  org.tensorflow.CustomGradientTest.lambda$testCustomGradient$1(Lorg/tensorflow/op/Ops;Lorg/tensorflow/op/nn/NthElement$Inputs;Ljava/util/List;)Ljava/util/List;+13
j  org.tensorflow.CustomGradientTest$$Lambda$394.call(Lorg/tensorflow/op/Ops;Lorg/tensorflow/op/RawOpInputs;Ljava/util/List;)Ljava/util/List;+6
j  org.tensorflow.op.TypedGradientAdapter.call(Lorg/tensorflow/internal/c_api/TF_Scope;Lorg/tensorflow/internal/c_api/NativeOperation;Lorg/tensorflow/internal/c_api/NativeOutputVector;Lorg/tensorflow/internal/c_api/NativeOutputVector;)Lorg/tensorflow/internal/c_api/NativeStatus;+125
v  ~StubRoutines::call_stub
j  org.tensorflow.internal.c_api.global.tensorflow.TF_AddGradientsWithPrefix(Lorg/tensorflow/internal/c_api/TF_Graph;Ljava/lang/String;Lorg/tensorflow/internal/c_api/TF_Output;ILorg/tensorflow/internal/c_api/TF_Output;ILorg/tensorflow/internal/c_api/TF_Output;Lorg/tensorflow/internal/c_api/TF_Status;Lorg/tensorflow/internal/c_api/TF_Output;)V+0
j  org.tensorflow.Graph.addGradients(Lorg/tensorflow/internal/c_api/TF_Graph;Ljava/lang/String;[Lorg/tensorflow/internal/c_api/TF_Operation;[I[Lorg/tensorflow/internal/c_api/TF_Operation;[I[Lorg/tensorflow/internal/c_api/TF_Operation;[I)[Ljava/lang/Object;+162
j  org.tensorflow.Graph.addGradients(Ljava/lang/String;[Lorg/tensorflow/Output;[Lorg/tensorflow/Output;[Lorg/tensorflow/Output;)[Lorg/tensorflow/Output;+221
j  org.tensorflow.Graph.addGradients(Lorg/tensorflow/Output;[Lorg/tensorflow/Output;)[Lorg/tensorflow/Output;+12
j  org.tensorflow.CustomGradientTest.testCustomGradient()V+101

The generated JNI code:

// JNI entry point backing the Java method NameMap.erase(BytePointer) (see the Java frame in the
// trace above). It calls erase() on the native
// std::unordered_map<tensorflow::string, tensorflow::Node*> wrapped by `obj`, using the string
// wrapped by `arg0` as the key, and returns erase()'s result as a jlong.
// Returns 0 with a pending Java exception when the wrapped map address is NULL.
JNIEXPORT jlong JNICALL Java_org_tensorflow_internal_c_1api_NameMap_erase__Lorg_bytedeco_javacpp_BytePointer_2(JNIEnv* env, jobject obj, jobject arg0) {
    // Recover the native map pointer from the address field of the Java object.
    std::unordered_map<tensorflow::string,tensorflow::Node*>* ptr = (std::unordered_map<tensorflow::string,tensorflow::Node*>*)jlong_to_ptr(env->GetLongField(obj, JavaCPP_addressFID));
    if (ptr == NULL) {
        // Raise a Java exception instead of dereferencing a null map pointer.
        // NOTE(review): the exception class behind index 7 is not visible in this snippet.
        env->ThrowNew(JavaCPP_getClass(env, 7), "This pointer address is NULL.");
        return 0;
    }
    // Apply the Java-side position as an offset, counted in whole map objects (pointer arithmetic
    // on a map pointer).
    jlong position = env->GetLongField(obj, JavaCPP_positionFID);
    ptr += position;
    // Unpack the key argument: address, limit, owner and position. A NULL arg0 yields a NULL
    // pointer with size 0 and position 0.
    signed char* ptr0 = arg0 == NULL ? NULL : (signed char*)jlong_to_ptr(env->GetLongField(arg0, JavaCPP_addressFID));
    jlong size0 = arg0 == NULL ? 0 : env->GetLongField(arg0, JavaCPP_limitFID);
    // NOTE(review): unlike the neighbouring reads, this call receives arg0 without a NULL guard;
    // how the helper treats NULL is not visible here.
    void* owner0 = JavaCPP_getPointerOwner(env, arg0);
    jlong position0 = arg0 == NULL ? 0 : env->GetLongField(arg0, JavaCPP_positionFID);
    // Skip to the current position and shrink the usable size accordingly.
    ptr0 += position0;
    size0 -= position0;
    // Wrap the raw buffer so it can be handed to the native call as a string.
    StringAdapter< char > adapter0(ptr0, size0, owner0);
    jlong rarg = 0;
    // The adapter is passed as a std::basic_string<char>& through a C-style cast and used as the
    // key to erase.
    // NOTE(review): the map is declared with tensorflow::string keys while the key is passed as
    // std::basic_string<char>; presumably these are the same type in this build. Confirm for the
    // Windows toolchain, since the crash trace above points into this function.
    jlong rval = ptr->erase((std::basic_string< char >&)adapter0);
    rarg = (jlong)rval;
    // Read the adapter's pointer, size and owner back after the call.
    signed char* rptr0 = adapter0;
    jlong rsize0 = (jlong)adapter0.size;
    void* rowner0 = adapter0.owner;
    if (rptr0 != ptr0) {
        // The adapter now reports a different buffer: re-point arg0 at it, registering
        // StringAdapter's deallocate function for that memory.
        JavaCPP_initPointer(env, arg0, rptr0, rsize0, rowner0, &StringAdapter< char >::deallocate);
    } else {
        // Same buffer: only refresh arg0's limit from the adapter's size.
        // NOTE(review): arg0 is written here without the NULL check applied to the reads above.
        env->SetLongField(arg0, JavaCPP_limitFID, rsize0 + position0);
    }
    return rarg;
}

Also, this binding seems to receive special treatment on our end: https://github.com/tensorflow/java/blob/455fc731ece7aba751e7ec7ee1cd06dff3a1bfe0/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java#L391

Originally posted by @karllessard in https://github.com/tensorflow/java/issues/484#issuecomment-1378232331