gdalle opened this issue 1 week ago

Hi @pabloferz!

Following your kind invitation, here's a prototype of what I would like to achieve in DifferentiationInterfaceJAX.jl: call a function defined in Python (fp) on Julia arrays (xj) with minimal overhead. I'm curious if there is any faster way to do things?
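
For reference, here is a sketch of the full round trip I have in mind. The name f_roundtrip and its structure are only illustrative (not actual DifferentiationInterfaceJAX.jl code); it just strings together the same share / from_dlpack calls used in the benchmark below:

using DLPack: share, from_dlpack
using PythonCall

jax = pyimport("jax")
jnp = pyimport("jax.numpy")

# Python-side function we want to call on Julia data
fp(xp) = jnp.square(xp)

# Illustrative round trip: Julia array in, Julia array out
function f_roundtrip(xj)
    xp = share(xj, jax.dlpack.from_dlpack)  # Julia -> JAX (first half of the overhead)
    yp = fp(xp)                             # run the Python function on the shared buffer
    return from_dlpack(yp)                  # JAX -> Julia (the half benchmarked below)
end

f_roundtrip(Float32.(1:10^5))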
Here's a benchmark for the other half of the overhead: moving back from JAX tensors to Julia arrays.
using BenchmarkTools
using DLPack: share, from_dlpack
using PythonCall
jax = pyimport("jax")
jnp = pyimport("jax.numpy")
fp(xp) = jnp.square(xp)
function fj(xp)
    yp = fp(xp)
    yj = from_dlpack(yp)
    return yj
end
xj = Float32.(1:10^5)
xp = share(xj, jax.dlpack.from_dlpack)

Benchmark results:

julia> @btime fp($xp);
  37.098 μs (11 allocations: 184 bytes)

julia> @btime fj($xp);
  49.526 μs (30 allocations: 2.02 KiB)
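
As a follow-up (sketch only, reusing fp, xp, and from_dlpack defined above), one could also time the conversion by itself, to see how much of the roughly 12 μs gap it accounts for:

yp = fp(xp)               # JAX array produced by the Python function
@btime from_dlpack($yp);  # cost of just the JAX -> Julia conversion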