juliatorch lets you convert Julia functions into PyTorch `autograd.Function`s, automatically differentiating the Julia functions in the process.

To install juliatorch, use Python 3.11 and pip:

pip install git+REPO_URL

(replace REPO_URL with the URL of the juliatorch Git repository)

Example usage

>>> from juliatorch import JuliaFunction
>>> import juliacall, torch
>>> f = juliacall.Main.seval("f(x) = exp.(-x .^ 2)")
>>> py_f = lambda x: f(x)
>>> x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
>>> JuliaFunction.apply(f, x)
tensor([[0.8583, 0.9999, 0.9712],
        [0.7043, 0.1852, 0.6042],
        [0.9968, 0.8472, 0.9913]], dtype=torch.float64,
>>> from torch.autograd import gradcheck
>>> gradcheck(JuliaFunction.apply, (py_f, x), eps=1e-6, atol=1e-4)

Using Julia's differential equation solvers in PyTorch

from juliatorch import JuliaFunction

import juliacall, torch

# Shortcut for evaluating Julia source strings from Python.
jl = juliacall.Main.seval

jl('import Pkg')
# One-time install of the Julia package; can be skipped once it is installed.
jl('Pkg.add("DifferentialEquations")')
jl('using DifferentialEquations')

# f(u0) solves u' = -u for t in [0, 1] starting from u0 and returns the final
# state, so f(u0) is approximately u0 * exp(-1).
f = jl("""
function f(u0)
    ode_f(u, p, t) = -u
    tspan = (0.0, 1.0)
    prob = ODEProblem(ode_f, u0, tspan)
    sol = DifferentialEquations.solve(prob)
    return sol.u[end]
end""")

print(f(1.0))
# 0.36787959342751697
print(f(2.0))
# 0.7357591870280833
print(f(2.0) / f(1.0))  # linear ODE: doubling u0 doubles the result (up to solver tolerance)
# 2.0000000004703966

x = torch.randn((3, 3), dtype=torch.float64, requires_grad=True)

# Solving u' = -u for one time unit scales every entry of u0 by exp(-1),
# so dividing the output by the input gives a constant matrix of ~0.3679.
print(JuliaFunction.apply(f, x) / x)
# tensor([[0.3679, 0.3679, 0.3679],
#         [0.3679, 0.3679, 0.3679],
#         [0.3679, 0.3679, 0.3679]], dtype=torch.float64, grad_fn=<DivBackward0>)

from torch.autograd import gradcheck
# Plain Python callable wrapping the Julia function (same pattern as the first example).
py_f = lambda u0: f(u0)
print(gradcheck(JuliaFunction.apply, (py_f, x), eps=1e-6, atol=1e-4))
# True -- gradients are correct not only for trivial Julia functions but also
#         straight through a full differential equation solver, on the first try.

Fitting a harmonic oscillator's parameter and initial conditions to match observations

This example uses diffeqpy to solve the differential equations and PyTorch to optimize the parameters.

from juliatorch import JuliaFunction
from diffeqpy import de
import juliacall, torch
# Shortcut for evaluating Julia source strings. It is not used below, where
# `de.seval` (from diffeqpy) is used to evaluate Julia code instead.
jl = juliacall.Main.seval

# Define the ODE kernel
def ode_f(du, u, p, t):
    """In-place right-hand side of the oscillator ODE x' = v, v' = -p * x.

    Reads the state ``u = [x, v]`` and writes the derivatives into ``du``.
    ``t`` is unused because the system does not depend on time explicitly.
    """
    position, velocity = u[0], u[1]
    du[0] = velocity
    du[1] = -p * position

# Use diffeqpy to solve the differential equation for given parameters
def solve(parameters):
    """Solve the oscillator ODE for ``parameters = (x0, v0, p)`` on t in [0, 1].

    ``x0`` and ``v0`` are the initial position and velocity, and ``p`` is passed
    to ``ode_f`` as its parameter. Returns the diffeqpy solution object, which
    can be called at a time ``t`` to obtain the state ``[x, v]``.
    """
    x0, v0, p = parameters
    tspan = (0.0, 1.0)
    # Why not just use `de.ODEProblem`? That would pass gradcheck but fail in the
    # optimization loop. Instead, the problem type is built explicitly:
    # `true` selects the in-place form (ode_f writes into `du`), and
    # `SciMLBase.FullSpecialize` requests full specialization on `ode_f`.
    prob = de.seval("ODEProblem{true, SciMLBase.FullSpecialize}")(ode_f, [x0, v0], tspan, p)
    return de.solve(prob)

# Extract the desired results
def solve_and_query(parameters):
    """Solve for ``parameters = (x0, v0, p)`` and sample the state at two times.

    Returns a 2x2 Julia matrix whose columns are the state ``[x, v]`` at
    t = 0.5 and t = 1.0, respectively.
    """
    solution = solve(parameters)
    at_half, at_end = solution(0.5), solution(1.0)
    return de.hcat(at_half, at_end)

# Plain Python numbers work too: x0 = 1, v0 = 2, p = 3.
print(solve_and_query([1, 2, 3]))
# [1.5274653930969104 0.9791625277649281; -0.023690980408490492 -2.0306945154435274]

x = torch.randn((3,), dtype=torch.float64, requires_grad=True)
print(JuliaFunction.apply(solve_and_query, x))
# tensor([[-0.4471, -0.3979],
#         [ 0.3155, -0.1103]], dtype=torch.float64,
#        grad_fn=<JuliaFunctionBackward>)

# Compare the autograd gradients through solve_and_query against finite differences.
from torch.autograd import gradcheck
print(gradcheck(JuliaFunction.apply, (solve_and_query, x), eps=1e-6, atol=1e-4))
# True

parameters = torch.tensor([1.0, 1.0, 1.0], requires_grad=True) # Initial guess for (x0, v0, p)
observations = torch.tensor([[ 0.4301,  0.3577], # Hardcode for consistency
                             [-0.3892, -1.6914]])
# Rows are x and v, columns are t = 0.5 and t = 1.0 (the layout returned by
# solve_and_query). v is only observed at t = 0.5, so v(1.0) gets zero weight.
weights = torch.tensor([[1.0, 1.0], [1.0, 0.0]])
n_steps = 10000
for learning_rate in [.03, .01, .003]: # Progressively smaller steps
    optimizer = torch.optim.SGD([parameters], lr=learning_rate)
    for _ in range(n_steps):
        optimizer.zero_grad() # Reset gradients; backward() would otherwise accumulate into .grad
        solution = JuliaFunction.apply(solve_and_query, parameters) # Solve the ODE
        loss = torch.norm(weights * (solution - observations)) # Define the loss function
        loss.backward() # Back-propagate the loss through all differentiable torch variables
        optimizer.step() # Update the parameters using the gradients computed by back-propagation

# It's worth rechecking that the gradient is still accurate because of Goodhart's Law:
print(gradcheck(JuliaFunction.apply, (solve_and_query, parameters), eps=1e-2, atol=1e-2))
# True

print(parameters)
# tensor([ 0.7748, -1.0569, -2.3015], requires_grad=True)
print(loss)
# tensor(0.0195, dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

# Plot the solution
from matplotlib import pyplot as plt
import numpy


def plot(parameters, observations):
    """Plot the simulated x(t) and v(t) on [0, 1] together with the observations.

    ``parameters`` is the (x0, v0, p) tensor being fitted; ``observations`` uses
    the same layout as the training targets (rows x and v, columns t = 0.5 and
    t = 1.0). Only v at t = 0.5 is plotted because v(1.0) is not observed.
    """
    sol = solve(parameters.detach().numpy())
    t = numpy.linspace(0, 1, 100)
    u = sol(t)
    plt.plot(t, u[0, :], label="simulated x")
    plt.plot(t, u[1, :], label="simulated v")
    plt.plot([.5, 1.0], observations[0, :], "o", label="observed x")
    plt.plot([.5], observations[1, 0], "o", label="observed v")
    plt.xlabel("t")
    plt.legend()  # Without this, the labels above are never displayed


plot(parameters, observations)
plt.show()  # Needed to display the figure when run as a script


