azane opened 10 months ago
Re: the note at the end of the issue above, I doubt these are required when jit compiling the dynamics; see my note at the end of #15 and in the comments for why working hard to solve this may not play to the strengths of using diffeqpy as a backend for chirho.
**Turns out, these modifications (or the one discussed below) are required when jit compiling.**
At least where chirho is concerned, I've implemented a robust (so far 🤷), if clunky, workaround for when different julia types are passed through a user-defined python function. TLDR is to wrap the julia "thing" in a python object that 1) obscures the underlying julia thing from numpy, preventing introspection and unpacking to >32 dims (which is the numpy limit), and 2) forwards all the relevant math dunder methods to the underlying julia entity. Arrays of `JuliaThingWrapper`s will have `dtype=np.object_`, but because the math dunder methods are all forwarded, standard vectorized math works.
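For reference, a minimal sketch of what that wrapping might look like. The real `JuliaThingWrapper` in the linked workaround is more complete; the class name matches the one discussed here, but the method selection and details below are just illustrative:

```python
import numpy as np


class JuliaThingWrapper:
    """Sketch: hide a julia scalar from numpy and forward math dunders to it."""

    def __init__(self, julia_thing):
        self.julia_thing = julia_thing

    @classmethod
    def wrap_array(cls, arr):
        # Build an object-dtype array so numpy never introspects the julia values
        # (and so can't try to unpack them into >32-dim arrays).
        out = np.empty(arr.shape, dtype=np.object_)
        for idx in np.ndindex(arr.shape):
            out[idx] = cls(arr[idx])
        return out

    @staticmethod
    def _unwrap(other):
        return other.julia_thing if isinstance(other, JuliaThingWrapper) else other

    # Forward the relevant math dunders to the underlying julia entity.
    def __add__(self, other):
        return JuliaThingWrapper(self.julia_thing + self._unwrap(other))

    def __radd__(self, other):
        return JuliaThingWrapper(self._unwrap(other) + self.julia_thing)

    def __mul__(self, other):
        return JuliaThingWrapper(self.julia_thing * self._unwrap(other))

    def __rmul__(self, other):
        return JuliaThingWrapper(self._unwrap(other) * self.julia_thing)

    def __truediv__(self, other):
        return JuliaThingWrapper(self.julia_thing / self._unwrap(other))

    def __repr__(self):
        return f"JuliaThingWrapper({self.julia_thing!r})"
```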
My preference was to find some numpy-internal way to disable the introspection of the julia types (and just treat them as objects), but after a few hours of research I only found reports from others who were likewise unable to find that ability within numpy itself. Perhaps there's an easier way to prevent this introspection from the julia side?
This approach won't be fast, because vectorized math on arrays of objects can't fall back to C/C++. It does work, however, and it lets the user ignore the nuances of the underlying julia type.
Though I don't show this in the example below, this wrapping strategy solves the same problem when jit compiling the dynamics with diffeqpy (the problem being the unpacking of julia symbolics to >32-dim arrays). The test linked here lists a variety of dynamics functions that jit-compile and pass gradcheck when the underlying julia symbolics are wrapped in a `JuliaThingWrapper` before being exposed to the user-supplied dynamics function. Note that the pre- and post-processing machinery happens internally and isn't shown in the test.
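To make that "internal pre- and post-processing" concrete, here is a rough sketch of the kind of shim I mean, sitting between the solver and the user-supplied dynamics function. The function and argument names here are hypothetical (not the actual chirho/diffeqpy internals), and `JuliaThingWrapper` is the wrapper sketched above:

```python
import juliacall
import numpy as np

# JuliaThingWrapper is defined in the workaround sketched above (not shown here).


def call_user_dynamics(dynamics, julia_state):
    """Hypothetical pre/post-processing shim around a user-supplied dynamics function.

    `julia_state` may be a juliacall.ArrayValue of Duals (backward pass) or of
    symbolics (jit compilation); the user function only ever sees wrapped values.
    """
    # Pre-process: convert to a numpy (object) array and wrap every julia element.
    if isinstance(julia_state, juliacall.ArrayValue):
        julia_state = julia_state.to_numpy()
    wrapped_state = JuliaThingWrapper.wrap_array(np.atleast_1d(julia_state))

    wrapped_out = np.atleast_1d(dynamics(wrapped_state))

    # Post-process: strip the wrappers so plain julia things flow back to the solver.
    unwrapped = np.empty(wrapped_out.shape, dtype=np.object_)
    for idx in np.ndindex(wrapped_out.shape):
        w = wrapped_out[idx]
        unwrapped[idx] = w.julia_thing if isinstance(w, JuliaThingWrapper) else w
    return unwrapped
```

In the examples in this thread the user function unwraps its own return value; the point of the sketch is only to show where the wrapping and unwrapping would sit.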
For chirho's current purposes, the overhead of wrapping arrays and the slow speed of vector math on ndarrays of objects only appear during jit compilation, so we aren't too concerned with optimizing this right now.
So, following the examples cited above, we can:
```python
import juliacall
import torch

# JuliaThingWrapper and JuliaFunction are defined in the workaround described above.

m = torch.tensor([1., 1.], requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)


def f(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()
    # Wrap so that standard numpy ops work.
    v = JuliaThingWrapper.wrap_array(v)

    m = v[:2]
    x = v[2]
    b = v[3]

    ret = (m * x + b).sum()

    # Unwrap what is now a wrapped scalar. If this were an array, the in-place
    # unwrap_array would be used.
    ret = ret.julia_thing
    print("JuliaThing", type(ret), ret)
    return ret


y = JuliaFunction.apply(f, torch.cat(torch.atleast_1d(m.ravel(), x, b)))
y.backward()
print(m.grad, x.grad, b.grad)
```
Which gives:
```
JuliaThing <class 'float'> 10.0
JuliaThing <class 'juliacall.RealValue'> Dual{ForwardDiff.Tag{var"#3#4"{Py, PyArray{Float64, 0, true, true, Float64}}, Float32}}(10.0,2.0,2.0,2.0,2.0)
tensor([2., 2.]) tensor(2.) tensor(2.)
```
My thinking is that something like this wrapping machinery could be applied to the dual-number array values before they are handed back to the user. The same wrapping setup could also wrap julia symbolics during jit compilation in diffeqpy.
As mentioned, this is working for us for now, so no rush here, but I wanted to flag this as a possible approach to addressing the type transparency issues.
Thank you!
On further investigation/testing, the above works for dunder methods but fails for numpy's universal functions. I've addressed this by splitting the functionality into dunder forwarding and ufunc forwarding, where elementwise ufuncs are forwarded to the julia function of the same name, both for a "scalar" `JuliaThingWrapper` and for every element of an array of "julia things". This requires a slightly different approach because, unlike with dunder methods, numpy doesn't utilize the `__array_ufunc__` of the elements in an array.
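Roughly, the ufunc forwarding might look like the sketch below. Again this is illustrative: the actual implementation, the name mapping, and the `_JuliaThingWrapperArray` details in the linked code differ, and `jl` here is just `juliacall.Main`. The math dunders from the earlier sketch are omitted:

```python
import numpy as np
from juliacall import Main as jl

# ufunc names that don't match the julia function name (illustrative, not exhaustive).
_UFUNC_TO_JULIA_NAME = {"absolute": "abs"}


def _julia_func(ufunc):
    return getattr(jl, _UFUNC_TO_JULIA_NAME.get(ufunc.__name__, ufunc.__name__))


class JuliaThingWrapper:
    def __init__(self, julia_thing):
        self.julia_thing = julia_thing

    # numpy consults this for e.g. np.abs(wrapper) on a single wrapped "scalar".
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__":
            return NotImplemented
        args = [x.julia_thing if isinstance(x, JuliaThingWrapper) else x for x in inputs]
        return JuliaThingWrapper(_julia_func(ufunc)(*args))


class _JuliaThingWrapperArray(np.ndarray):
    # numpy does *not* consult the elements' __array_ufunc__ for object arrays,
    # so wrap_array returns this ndarray subclass and we forward element by element.
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        if method != "__call__":
            return NotImplemented
        out = np.empty(self.shape, dtype=np.object_).view(_JuliaThingWrapperArray)
        for idx in np.ndindex(self.shape):
            args = [x[idx] if isinstance(x, np.ndarray) else x for x in inputs]
            args = [a.julia_thing if isinstance(a, JuliaThingWrapper) else a for a in args]
            out[idx] = JuliaThingWrapper(_julia_func(ufunc)(*args))
        return out
```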
With all that in place, we can:
```python
import juliacall
import numpy as np
import torch

# JuliaThingWrapper and JuliaFunction are defined in the workaround described above.


def f(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()
    # Wrap so that standard numpy ops work.
    v = JuliaThingWrapper.wrap_array(v)

    # Using ufuncs on arrays.
    print("Array", type(v))
    v = np.exp(np.log(v))

    m = v[:2]
    x = v[2]
    b = v[3]

    # Using ufuncs on an individual element.
    print("Element", type(b))
    b = (b / np.abs(b)) * np.abs(b)

    wrapped_ret = (m * x + b).sum()

    # Unwrap what is now a wrapped scalar. If this were an array, the in-place
    # unwrap_array would be used.
    ret = wrapped_ret.julia_thing
    print("JuliaThing", type(ret), ret)
    return ret


m = torch.tensor([1., 1.])
x = torch.tensor(2.)
b = torch.tensor(3.)
mxb_cat = torch.cat(torch.atleast_1d(m, x, b)).double().requires_grad_()

print("Forward")
y = JuliaFunction.apply(f, mxb_cat)

print("\nBackward")
y.backward()

print("\nGrad")
print(mxb_cat.grad)

# torch.autograd.gradcheck(JuliaFunction.apply, (f, mxb_cat))
```
Which gives:

```
Forward
Array <class '__main__._JuliaThingWrapperArray'>
Element <class '__main__.JuliaThingWrapper'>
JuliaThing <class 'float'> 10.0

Backward
Array <class '__main__._JuliaThingWrapperArray'>
Element <class '__main__.JuliaThingWrapper'>
JuliaThing <class 'juliacall.RealValue'> Dual{ForwardDiff.Tag{var"#3#4"{Py, PyArray{Float64, 0, true, true, Float64}}, Float64}}(10.0,2.0,2.0,2.0,2.0)

Grad
tensor([2., 2., 2., 2.], dtype=torch.float64)
```
Also note that I haven't exhaustively tested all ufuncs and all of the forwarded dunder methods. There are bound to be edge cases here where the automated dispatching won't work out of the box.
Possibly related: https://github.com/JuliaPy/PythonCall.jl/issues/390
Many different types end up going through the python function being differentiated. First, it might be evaluated sometimes as a standard python function involving torch tensors (not as part of this package, but related to #15). Then, within a `JuliaFunction` execution, it first sees standard `np.ndarray`s, and then `juliacall.ArrayValue`s of `Dual` numbers.

Vectorized operations fail on these `juliacall.ArrayValue`s (with `Julia: MethodError: no method matching` errors) where they succeed in numpy and torch. To work with `Tensor`s, `ndarray`s, and the backward pass involving `ArrayValue`s of `Dual`s, we must, for example:

This is okay, but more issues arise when dealing with `juliacall.RealValue` and `juliacall.ArrayValue` in combination. The example below fails on the backward pass with a `Julia: MethodError: no method matching` error that requests julia's elementwise broadcasting syntax (`use dot syntax: array .+ scalar`).

This can be partially addressed by inserting a conditional `v = v.to_numpy()` as before, but this results in `type(m)` being an `np.ndarray` with `dtype == object`, and `x` and `b` being `juliacall.RealValue` `Dual`s. When `m * x` happens in this case, `np.asarray(x)` is implicitly called, which results in `ValueError: setting an array element with a sequence. The requested array would exceed the maximum number of dimension of 32`. I.e. it seems to attempt to make an array with the full buffer of the `Dual`?

This can be made to run with the following modifications, where we ensure that scalars stay in vectors of length 1.
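(The modified snippet itself isn't reproduced above. Purely as an illustration of the length-1-vector idea, and reusing the `m`/`x`/`b` names from the earlier examples, the change might look roughly like this:)

```python
import juliacall
import numpy as np


def f(v):
    # Convert julia arrays to numpy (object dtype when the elements are Duals).
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()
    m = v[0:2]
    # Slice instead of indexing so the "scalars" stay as length-1 arrays; this
    # avoids numpy implicitly calling np.asarray on a bare Dual scalar.
    x = v[2:3]
    b = v[3:4]
    return (m * x + b).sum()
```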
As tagged in the title, this is non-blocking, but the workarounds are quite involved/nuanced, and I haven't discussed here the subtleties needed when using the same function for both numpy and torch.
Ideally, everything could be converted to torch tensors before the python function ever sees it; this would probably work fine on the forward pass in combination with `torch.no_grad`, but obviously torch does not support `Dual` as a dtype.

**TODO**
It's also not clear to me whether these modifications can jit compile (haven't dug into that yet), or whether they are still required when you jit compile.