tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

Add C API extension for custom gradient functions #511

Closed. karllessard closed this 6 months ago

karllessard commented 6 months ago

This PR introduces a new C API extension, maintained directly in the TF Java sources, that enables capabilities not exposed by the official TensorFlow C API. It also simplifies generating the Java bindings for the C API with JavaCPP, since we no longer need to parse complex, internal, and unstable C++ data structures.

The first capability enabled by this extension is custom gradient functions written in Java, which have so far been excluded from the bazelcism development branches.

The PR also introduces a series of minor changes.
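
For readers unfamiliar with the idea, here is a rough sketch of what a custom gradient function registered from Java could look like conceptually. It is only an illustration under simplifying assumptions: tensors are reduced to `double[]` values, and every name in it (`CustomGradientSketch`, `GradientFunction`, `register`, `gradientsFor`) is hypothetical and not taken from this PR, from TF Java, or from the TensorFlow C API.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: none of these names come from TensorFlow Java or its C API.
final class CustomGradientSketch {

    /** Computes gradients w.r.t. an op's inputs from those inputs and the upstream gradients. */
    interface GradientFunction {
        double[] apply(double[] inputs, double[] upstreamGrads);
    }

    // Registry keyed by op type name. In a real binding, the native side would
    // hold the lookup and call back into Java; here it is a plain map.
    private static final Map<String, GradientFunction> REGISTRY = new HashMap<>();

    /** Registers a gradient function; returns false if the op type already has one. */
    static boolean register(String opType, GradientFunction fn) {
        return REGISTRY.putIfAbsent(opType, fn) == null;
    }

    /** Stand-in for the callback a native runtime would trigger while building gradients. */
    static double[] gradientsFor(String opType, double[] inputs, double[] upstreamGrads) {
        GradientFunction fn = REGISTRY.get(opType);
        if (fn == null) {
            throw new IllegalStateException("No gradient registered for op type " + opType);
        }
        return fn.apply(inputs, upstreamGrads);
    }
}
```

With this sketch, a gradient for a made-up `Square` op (y = x², so dx = 2·x·dy) could be registered as `CustomGradientSketch.register("Square", (in, dy) -> new double[] {2.0 * in[0] * dy[0]})` and then looked up through `gradientsFor`. A second registration for the same op type is rejected rather than silently overwriting the first, which is one possible design choice and not necessarily the one made here.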

karllessard commented 6 months ago

@saudet if you can also review the native and JavaCPP part, that would be great, thank you

saudet commented 6 months ago

> @saudet if you can also review the native and JavaCPP part, that would be great, thank you

Looks alright. I usually put "native" stuff in src/main/resources to make it available to end users, so they can use it for their own native libraries, but I don't think that is going to happen here, so I guess it doesn't matter.

karllessard commented 6 months ago

Ok @Craigacp, pushed my changes. I've also finally switched to the TFJ_ prefix.