JuliaPy / PythonCall.jl

Python and Julia in harmony.
https://juliapy.github.io/PythonCall.jl/stable/
MIT License

How to pass Julia function to Python? #142

Closed Octogonapus closed 2 years ago

Octogonapus commented 2 years ago

I need to pass some Julia functions as callbacks to a Python function. I've scoured the documentation but I can't figure out how to pass callbacks to Python.

My code:

module IoTTest

using CondaPkg
using PythonCall

const awsiot = PythonCall.pynew()
const mqtt_connection_builder = PythonCall.pynew()
const awscrt = PythonCall.pynew()
const mqtt = PythonCall.pynew()

function __init__()
    CondaPkg.add_pip("awsiotsdk")
    PythonCall.pycopy!(awsiot, pyimport("awsiot"))
    PythonCall.pycopy!(mqtt_connection_builder, pyimport("awsiot.mqtt_connection_builder"))
    PythonCall.pycopy!(awscrt, pyimport("awscrt"))
    PythonCall.pycopy!(mqtt, pyimport("awscrt.mqtt"))
end

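function on_message_received(topic, payload, dup, qos, retain; kwargs...)
    # received_count and received_all_latch are module-level state (definitions omitted here).
    global received_count
    @info "Received message" topic payload
    received_count += 1
    count_down(received_all_latch)
end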

function main()
    mqtt_connection = mqtt_connection_builder.mtls_from_path(;
        # args omitted
        on_connection_interrupted,
        on_connection_resumed,
    )

    subscribe_future, packet_id = mqtt_connection.subscribe(;
        topic = message_topic,
        qos = mqtt.QoS.AT_LEAST_ONCE,
        callback = on_message_received,
    )
end
end

The error I get:

ERROR: Python: TypeError: Julia: MethodError: no method matching length(::typeof(Main.IoTTest.on_message_received))
Closest candidates are:
  length(!Matched::Union{Base.KeySet, Base.ValueIterator}) at ~/julia-1.7.2/share/julia/base/abstractdict.jl:58
  length(!Matched::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S}) at ~/julia-1.7.2/share/julia/stdlib/v1.7/LinearAlgebra/src/adjtrans.jl:171
  length(!Matched::Union{Tables.AbstractColumns, Tables.AbstractRow}) at ~/.julia/packages/Tables/PxO1m/src/Tables.jl:175
  ...
Python stacktrace:
 [1] __len__
   @ /home/salmon/.julia/packages/PythonCall/XgP8G/src/jlwrap/any.jl:168:32
 [2] subscribe
   @ awscrt.mqtt .../dev/.CondaPkg/env/lib/python3.10/site-packages/awscrt/mqtt.py:502
Stacktrace:
 [1] pythrow()
   @ PythonCall ~/.julia/packages/PythonCall/XgP8G/src/err.jl:94
 [2] errcheck
   @ ~/.julia/packages/PythonCall/XgP8G/src/err.jl:10 [inlined]
 [3] pycallargs
   @ ~/.julia/packages/PythonCall/XgP8G/src/abstract/object.jl:154 [inlined]
 [4] pycall(::PythonCall.Py; kwargs::Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:topic, :qos, :callback), Tuple{String, PythonCall.Py, typeof(Main.IoTTest.on_message_received)}}})
   @ PythonCall ~/.julia/packages/PythonCall/XgP8G/src/abstract/object.jl:165
 [5] #_#11
   @ ~/.julia/packages/PythonCall/XgP8G/src/Py.jl:360 [inlined]
 [6] main()
   @ Main.IoTTest .../dev/iot_test.jl:111

This is the Python code of my dependency:

def subscribe(self, topic, qos, callback=None):
        future = Future()
        packet_id = 0

        if callback:
            def callback_wrapper(topic, payload, dup, qos, retain):
                try:
                    callback(topic=topic, payload=payload, dup=dup, qos=QoS(qos), retain=retain)
                except TypeError:
                    # This callback used to have fewer args.
                    # Try again, passing only those args, to cover case where
                    # user function failed to take forward-compatibility **kwargs.
                    callback(topic=topic, payload=payload)
        else:
            callback_wrapper = None

        def suback(packet_id, topic, qos, error_code):
            if error_code:
                future.set_exception(awscrt.exceptions.from_code(error_code))
            else:
                qos = _try_qos(qos)
                if qos is None:
                    future.set_exception(SubscribeError(topic))
                else:
                    future.set_result(dict(
                        packet_id=packet_id,
                        topic=topic,
                        qos=qos,
                    ))

        try:
            assert callable(callback) or callback is None
            assert isinstance(qos, QoS)
            packet_id = _awscrt.mqtt_client_connection_subscribe(
                self._binding, topic, qos.value, callback_wrapper, suback)
        except Exception as e:
            future.set_exception(e)

        return future, packet_id

Changing "if callback:" to "if callback is not None:" in the dependency avoids the error, but (1) I don't want to start maintaining patches for my dependencies, and (2) callbacks passed this way lead to segfaults later on when they are invoked, so I think I am just doing something wrong.
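
To make the difference between the two guards concrete, here is a minimal pure-Python sketch. It does not use PythonCall or awscrt: `LenRaisingWrapper`, `truthiness_guard`, and `identity_guard` are hypothetical names, and the wrapper only imitates the reported symptom (a `__len__` that raises `TypeError`), on the assumption that this is roughly what the wrapped Julia function looks like from the Python side.

```python
class LenRaisingWrapper:
    """Hypothetical stand-in for a generic wrapper around a foreign function."""

    def __init__(self, func):
        self._func = func

    def __call__(self, *args, **kwargs):
        return self._func(*args, **kwargs)

    def __len__(self):
        # Imitates the reported failure: the wrapped value has no length.
        raise TypeError("no method matching length(...)")


def truthiness_guard(callback):
    # Same shape as "if callback:" in subscribe(). With no __bool__ defined,
    # Python falls back to __len__ to decide truthiness.
    if callback:
        return "wrapped"
    return "no callback"


def identity_guard(callback):
    # Same shape as "if callback is not None:". Never touches __len__.
    if callback is not None:
        return "wrapped"
    return "no callback"


wrapper = LenRaisingWrapper(lambda **kwargs: kwargs)

try:
    truthiness_guard(wrapper)
    truthiness_error = None
except TypeError as exc:
    truthiness_error = exc

print(type(truthiness_error).__name__)  # the truthiness test itself raised
print(identity_guard(wrapper))
print(identity_guard(None))
```

The object is perfectly callable; only the truthiness test fails, which matches the `__len__` frame at the top of the Python stacktrace above.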

Julia Version 1.7.2
Commit bf53498635 (2022-02-06 15:21 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i9-9900K CPU @ 3.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)

PythonCall v0.8.0

cjdoris commented 2 years ago

Ok so the approach looks fine, but indeed there is a bug in getting the truthiness of a wrapped Julia value, which I'll fix.

You can also pass pyfunc(f) instead of the function f. The difference is that the arguments passed to pyfunc(f) are always Py (which is normally appropriate for callback functions) whereas the arguments to f are converted first. As a handy side effect, pyfunc(f) shouldn't have the truthy bug.
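
As a rough illustration of why a function-like object would get past the checks in subscribe(), here is a pure-Python sketch. It assumes, without confirming it against PythonCall itself, that what pyfunc(f) hands to Python behaves like an ordinary callable that defines neither `__bool__` nor `__len__`. `CallOnlyWrapper` and `passes_subscribe_guards` are hypothetical names used only for this example.

```python
def on_message_received(topic, payload, **kwargs):
    # Plain Python function used as a stand-in callback.
    return (topic, payload)


class CallOnlyWrapper:
    """Hypothetical wrapper that exposes __call__ and nothing else."""

    def __init__(self, func):
        self._func = func

    def __call__(self, *args, **kwargs):
        return self._func(*args, **kwargs)


def passes_subscribe_guards(callback):
    # Mirrors the two checks in subscribe(): "if callback:" and
    # "assert callable(callback)". Objects without __bool__ or __len__
    # are truthy by default, so neither check can raise here.
    return bool(callback) and callable(callback)


print(passes_subscribe_guards(on_message_received))
print(passes_subscribe_guards(CallOnlyWrapper(on_message_received)))
```

Both calls report that the guards pass, because Python never needs to ask either object for its length.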

I'll need a MWE if you want me to look at the segfaults.

Octogonapus commented 2 years ago

Thanks, the pyfunc method works.

As for the segfault, I'm happy to provide the full code, though it relies on you having certain AWS resources. Not sure if you want to provision those. If you don't, I'm also happy to grant you access to my resources so that you can test with them. I'd want to transfer the details in a secure way, of course. Let me know if either option interests you.

cjdoris commented 2 years ago

Closing as the main issue is fixed. Feel free to open an issue if you're still having those segfaults.