jpy-consortium / jpy

Apache License 2.0

Add jpy.byte_buffer() function #112

Closed: jmao-denver closed this 6 months ago

jmao-denver commented 9 months ago

Fixes #111. https://github.com/deephaven/deephaven-core/pull/4936 depends on this PR.

  1. A byte_buffer() utility function that creates a Java direct ByteBuffer wrapper sharing the underlying memory of a Python object that implements the buffer protocol.
  2. When calling a Java method that takes a ByteBuffer argument, such a Java ByteBuffer wrapper can be passed in.
  3. When calling a Java method whose last parameter is a variadic ByteBuffer parameter, a sequence of Java ByteBuffer wrappers created by byte_buffer() is accepted.
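byte_buffer() builds on the Python buffer protocol, the same mechanism memoryview uses. As a pure-Python illustration of what "sharing the underlying memory" means (no jpy calls; describe() is a hypothetical helper), a view exposes the exporter's memory directly, reports whether that memory is read-only, and sees writes without any copy being made:

```python
# Pure-Python illustration of the buffer protocol; no jpy involved.
# memoryview consumes the same protocol that C extensions use to reach
# an object's raw memory.

def describe(obj):
    """Hypothetical helper: return (nbytes, readonly) reported by the buffer protocol."""
    with memoryview(obj) as view:
        return view.nbytes, view.readonly

immutable = b"abc"            # bytes exports a read-only buffer
mutable = bytearray(b"abc")   # bytearray exports a writable buffer

print(describe(immutable))    # (3, True)
print(describe(mutable))      # (3, False)

# A view shares memory with its exporter: writing through the view is
# visible in the original object because nothing was copied.
view = memoryview(mutable)
view[0] = ord("x")
print(mutable)                # bytearray(b'xbc')
view.release()
```

This thread doesn't say how jpy maps the read-only flag onto the Java side; the isReadOnly() line printed by check_jbb is the place to confirm it.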

Note: Java methods that receive ByteBuffer arguments are considered to borrow these buffers, not own them, and the buffers are only guaranteed to be safe to access for the duration of the method call. If Java needs a buffer after the call returns, it is up to the user either to keep the Python object alive (so it is not garbage-collected) after the Java method finishes, or to copy the buffer's contents before the method returns.
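A minimal, hypothetical sketch of the first option, keeping the Python objects alive. BufferKeepAlive is not a jpy API; it simply holds strong references, which is what keeps a CPython object (and the memory it exports) from being freed. It makes no jpy calls, so it runs without a JVM. The copying option has to happen on the Java side before the method returns, so it is not shown. One caveat: keeping an object alive does not stop a mutable exporter such as bytearray from being resized later, which can move its memory, so pinned objects should not be resized.

```python
class BufferKeepAlive:
    """Hypothetical helper, not part of jpy: holds strong references to Python
    buffer objects so they are not garbage-collected while Java code may still
    read memory that was wrapped with jpy.byte_buffer()."""

    def __init__(self):
        self._pinned = []

    def pin(self, obj):
        # memoryview() raises TypeError if obj lacks the buffer protocol
        memoryview(obj).release()
        self._pinned.append(obj)
        return obj

    def release_all(self):
        # drop our references; objects become collectable once nothing else refers to them
        self._pinned.clear()

    def __len__(self):
        return len(self._pinned)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.release_all()
        return False


# Usage: pin the object for as long as Java may hold the wrapper.
keep_alive = BufferKeepAlive()
payload = keep_alive.pin(b"hello world")
# ... wrap payload with jpy.byte_buffer() and hand it to Java here ...
keep_alive.release_all()  # only once Java no longer uses the buffer
```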

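import jpyutil
jpyutil.init_jvm()  # start the JVM before importing jpy
import jpy

def check_jbb(jbb):
    """Print a Java ByteBuffer's state, consume it with relative get() calls, and print it again."""
    print(jbb.toString())
    print(f"isReadOnly: {jbb.isReadOnly()}")
    print("before scanning...")
    print(f"remaining: {jbb.remaining()}")
    print(f"position: {jbb.position()}")
    # each get() returns the byte at the current position and advances it by one
    for i in range(jbb.remaining()):
        print(i, jbb.get())
    print("after scanning...")
    print(jbb.toString())
    print(f"remaining: {jbb.remaining()}")
    print(f"position: {jbb.position()}")

# wrap a Python bytes object; the Java ByteBuffer shares its memory
ba = b'abc'
jbb = jpy.byte_buffer(ba)
check_jbb(jbb)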

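import pyarrow as pa

# three columns of different types, some containing nulls
data = [
    pa.array([1, 2, 3, 4]),
    pa.array(['foo', 'bar', 'baz', None]),
    pa.array([True, None, False, True])
]

batch = pa.record_batch(data, names=['f0', 'f1', 'f2'])

# serialize five copies of the batch in the Arrow IPC streaming format
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    for _ in range(5):
        writer.write_batch(batch)

# pa.Buffer implements the buffer protocol, so it can be wrapped directly
buf = sink.getvalue()
jbb = jpy.byte_buffer(buf)
check_jbb(jbb)

# the serialized schema on its own is also a pa.Buffer
buf = batch.schema.serialize()
jbb = jpy.byte_buffer(buf)
check_jbb(jbb)
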
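check_jbb reads with relative get() calls, so after the loop position() equals the old limit and remaining() is 0. The toy model below mimics only that position/limit bookkeeping of java.nio.ByteBuffer; ToyByteBuffer is hypothetical and says nothing about how jpy is implemented. It also reflects one Java detail worth keeping in mind when reading check_jbb's output: Java's byte is signed, so bytes 0x80 to 0xFF are read as negative values in Java.

```python
class ToyByteBuffer:
    """Hypothetical pure-Python model of ByteBuffer's relative reads
    (position/limit bookkeeping only); not a jpy or Java API."""

    def __init__(self, data):
        # cast("B") gives a flat view of unsigned bytes over the exporter's memory
        self._view = memoryview(data).cast("B")
        self._position = 0
        self._limit = len(self._view)

    def position(self):
        return self._position

    def remaining(self):
        return self._limit - self._position

    def get(self):
        if self._position >= self._limit:
            # Java would throw java.nio.BufferUnderflowException here
            raise IndexError("buffer underflow")
        value = self._view[self._position]
        self._position += 1
        # Java's byte is signed: 0x80..0xFF map to -128..-1
        return value - 256 if value > 127 else value


buf = ToyByteBuffer(b"ab\xff")
print(buf.position(), buf.remaining())               # 0 3
print([buf.get() for _ in range(buf.remaining())])   # [97, 98, -1]
print(buf.position(), buf.remaining())               # 3 0
```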
devinrsmith commented 9 months ago

Another small wrinkle for the future: Java 21 has long for capacity.
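For context on this wrinkle: java.nio.Buffer.capacity() returns an int, so a ByteBuffer can describe at most 2**31 - 1 bytes, while Python buffer objects can be larger. A caller-side guard like the following could reject oversize buffers early instead of relying on whatever jpy.byte_buffer() does with them (that behavior isn't covered in this thread). check_byte_buffer_capacity is a hypothetical helper, not a jpy API, and its limit parameter exists only so the check can be exercised with small inputs.

```python
JAVA_INT_MAX = 2**31 - 1  # java.nio.Buffer.capacity() returns an int


def check_byte_buffer_capacity(obj, limit=JAVA_INT_MAX):
    """Hypothetical pre-check (not a jpy API): return the byte length of obj's
    buffer, or raise ValueError if it is too large for an int capacity."""
    with memoryview(obj) as view:
        nbytes = view.nbytes
    if nbytes > limit:
        raise ValueError(
            f"buffer of {nbytes} bytes exceeds the int capacity limit of {limit}"
        )
    return nbytes


print(check_byte_buffer_capacity(b"hello world"))  # 11
```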

devinrsmith commented 6 months ago

Here's a script I've been using:

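import jpy
from contextlib import contextmanager

# Assumes the JVM is already running with Deephaven's classes on the classpath
# (e.g. inside a Deephaven session), so these types can be resolved.
_JByteBuffer = jpy.get_type('java.nio.ByteBuffer')
_JArrowToTableConverter = jpy.get_type('io.deephaven.extensions.barrage.util.ArrowToTableConverter')

@contextmanager
def jpy_flags(flags):
    """Temporarily set jpy's diagnostic flags, restoring the previous value on exit."""
    orig_flags = jpy.diag.flags
    jpy.diag.flags = flags
    try:
        yield
    finally:
        jpy.diag.flags = orig_flags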

def buffer_protocol():
    return jpy.byte_buffer(b'hello world')

def j_byte_buffer():
    return _JByteBuffer.allocate(42)

def j_buffer():
    # return type is java.nio.Buffer *not* ByteBuffer
    # impl is ByteBuffer.allocateDirect(size)
    return _JArrowToTableConverter.myBuffer(43)

def j_object():
    # return type is java.lang.Object 
    # impl is new Object()
    return _JArrowToTableConverter.newObject()

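def print_info(name, x):
    print(f"{name}: type(x)={type(x)}\n")
    print(f"{name}: x={x}\n")

def create_print_del(name, fn):
    """Create an object with fn(), print it with diagnostics off, then delete it."""
    my_obj = fn()
    # switch diagnostics off so printing does not add to the log
    with jpy_flags(jpy.diag.F_OFF):
        print_info(name, my_obj)
    del my_obj

# log jpy memory (F_MEM) and type (F_TYPE) diagnostics while each kind of
# object is created and released
with jpy_flags(jpy.diag.F_MEM | jpy.diag.F_TYPE):
    create_print_del('j_object', j_object)
    create_print_del('buffer_protocol', buffer_protocol)
    create_print_del('j_byte_buffer', j_byte_buffer)
    create_print_del('j_buffer', j_buffer)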

I added some public static methods (myBuffer and newObject) to ArrowToTableConverter just because it was a convenient place to put some Java logic.
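jpy_flags in the script is an instance of a general save/set/restore pattern: the try/finally is what guarantees the original flags come back even if the body raises. A generic sketch of the same pattern that runs without a JVM; override_attr and the flag values are hypothetical, and SimpleNamespace stands in for jpy.diag.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_attr(obj, name, value):
    """Set obj.<name> = value for the duration of the block, then restore it."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        # runs on normal exit and also when the block raises
        setattr(obj, name, original)


# Stand-in for jpy.diag; the integer flag values are arbitrary.
fake_diag = SimpleNamespace(flags=0)

with override_attr(fake_diag, "flags", 0x10 | 0x01):
    print(fake_diag.flags)   # 17

try:
    with override_attr(fake_diag, "flags", 0xFF):
        raise RuntimeError("boom")
except RuntimeError:
    pass

print(fake_diag.flags)       # 0
```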

devinrsmith commented 6 months ago

I've added some logging on this branch: https://github.com/devinrsmith/jpy/tree/111-DirectByteBuffer-support