pabloferz / DLPack.jl

Julia interface for DLPack

Performance optimization for DifferentiationInterfaceJAX #42

Open gdalle opened 1 week ago

gdalle commented 1 week ago

Hi @pabloferz!

Following your kind invitation, here's a prototype of what I would like to achieve in DifferentiationInterfaceJAX.jl: call a function defined in Python (fp) on Julia arrays (xj) with minimal overhead. I'm curious whether there is a faster way to do this?

using BenchmarkTools
using DLPack: share, from_dlpack
using PythonCall

jax = pyimport("jax")
jnp = pyimport("jax.numpy")

fp(xp) = jnp.sum(jnp.square(xp))

function fj(xj)
    # wrap the Julia array as a JAX array without copying the data
    xp = share(xj, jax.dlpack.from_dlpack)
    return fp(xp)
end

function fjp!(xp_scratch, xj_scratch, xj)
    # assumes xp_scratch and xj_scratch share the same memory (created with share)
    copyto!(xj_scratch, xj)
    return fp(xp_scratch)
end

xj = Float32.(1:10^5)
xp = share(xj, jax.dlpack.from_dlpack)  # xp aliases xj

xj_scratch = Vector{Float32}(undef, 10^5)
xp_scratch = share(xj_scratch, jax.dlpack.from_dlpack)  # xp_scratch aliases xj_scratch

Benchmark results:

julia> @btime fj($xj)
  88.455 μs (56 allocations: 2.52 KiB)
Python: Array(3.3333832e+14, dtype=float32)

julia> @btime fp($xp)
  14.754 μs (22 allocations: 368 bytes)
Python: Array(3.3333832e+14, dtype=float32)

julia> @btime fjp!($xp_scratch, $xj_scratch, $xj)
  28.563 μs (22 allocations: 368 bytes)
Python: Array(3.3333832e+14, dtype=float32)
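
Part of the gap between fj and fp may come from resolving jax.dlpack.from_dlpack through PythonCall on every call. Here is a minimal sketch of hoisting that lookup, assuming the same setup as above (jax_from_dlpack and fj_cached are hypothetical names):

const jax_from_dlpack = jax.dlpack.from_dlpack  # resolve the Python attribute once

function fj_cached(xj)
    # share still wraps without copying; the converter is already resolved
    xp = share(xj, jax_from_dlpack)
    return fp(xp)
end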
gdalle commented 3 days ago

Here's a benchmark for the other half of the overhead: moving results back from JAX arrays to Julia arrays:

using BenchmarkTools
using DLPack: share, from_dlpack
using PythonCall

jax = pyimport("jax")
jnp = pyimport("jax.numpy")

fp(xp) = jnp.square(xp)

function fj(xp)
    yp = fp(xp)
    yj = from_dlpack(yp)  # wrap the JAX result as a Julia array without copying
    return yj
end

xj = Float32.(1:10^5)
xp = share(xj, jax.dlpack.from_dlpack)
Benchmark results:

julia> @btime fp($xp);
  37.098 μs (11 allocations: 184 bytes)

julia> @btime fj($xp);
  49.526 μs (30 allocations: 2.02 KiB)
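
For completeness, here is a minimal sketch of a preallocated variant for this direction, mirroring fjp! above (fj! and yj_scratch are hypothetical names, assuming the output shape is known in advance):

# copy the JAX result into a preallocated Julia buffer
function fj!(yj_scratch, xp)
    yp = fp(xp)
    copyto!(yj_scratch, from_dlpack(yp))  # one copy; the from_dlpack wrapper itself is zero-copy
    return yj_scratch
end

yj_scratch = Vector{Float32}(undef, 10^5)

This still pays the from_dlpack wrapping cost on each call, but it returns a plain Vector{Float32} owned by Julia rather than a view into JAX-managed memory.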