SciML / juliatorch

Convert Julia functions to PyTorch autograd functions

Improved Type Transparency Python-side (Non-Blocking) #14

Open azane opened 10 months ago

azane commented 10 months ago

Many different types end up going through the python function being differentiated. First, it might sometimes be evaluated as a standard python function operating on torch tensors (not as part of this package, but related to #15). Then, within a JuliaFunction execution, it first sees standard np.ndarrays, and then juliacall.ArrayValues of Dual numbers.

Vectorized operations that succeed on numpy arrays and torch tensors fail on these juliacall.ArrayValues (with `Julia: MethodError: no method matching` errors). To work with Tensors, ndarrays, and the backward pass involving ArrayValues of Duals, we must, for example:

m = torch.tensor([1., 1.], requires_grad=True)
x = torch.tensor([2., 2.], requires_grad=True)
b = torch.tensor([3., 3.], requires_grad=True)

def f(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()
    m = v[0, :]
    x = v[1, :]
    b = v[2, :]
    return (m * x + b).sum()

y = JuliaFunction.apply(f, torch.cat(torch.atleast_2d(m, x, b)))
y.backward()

This is okay, but more issues arise when juliacall.RealValue and juliacall.ArrayValue appear in combination. The example below fails on the backward pass with a `Julia: MethodError: no method matching` error that suggests using julia's elementwise broadcasting (dot) syntax: array .+ scalar.

m = torch.tensor([1., 1.], requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

def f(v):
    m = v[:2]
    x = v[2]
    b = v[3]
    return (m * x + b).sum()

y = JuliaFunction.apply(f, torch.cat(torch.atleast_1d(m.ravel(), x, b)))
y.backward()

This can be partially addressed by inserting a conditional v = v.to_numpy() as before, but this results in type(m) being an np.ndarray with dtype == object, and x and b being juliacall.RealValue Duals. When m * x happens in this case, np.asarray(x) is implicitly called, which results in ValueError: setting an array element with a sequence. The requested array would exceed the maximum number of dimension of 32. I.e. it seems to attempt to make an array with the full buffer of the Dual?

This can be made to run with the following modifications, where we ensure that scalars stay in vectors of length 1.

m = torch.tensor([1., 1.], requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

def f(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()
    m = v[:2]
    x = v[2, None]
    b = v[3, None]
    return (m * x + b).sum()

y = JuliaFunction.apply(f, torch.cat(torch.atleast_1d(m.ravel(), x, b)))
y.backward()

As tagged in the title, this is non-blocking, but the workarounds are quite involved/nuanced, and I haven't discussed here the subtleties needed when using the same function for both numpy and torch.

Ideally, everything could be converted to torch tensors before the python function ever sees it — this would probably work fine on the forward pass in combination with torch.no_grad, but obviously torch does not support Dual as a dtype.
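As a rough sketch of that idea (purely illustrative; the decorator name and placement below are hypothetical and not part of juliatorch), the conversion would only ever apply to the plain-float forward pass:

```python
import numpy as np
import torch


def tensorize_forward_inputs(f):
    # Hypothetical adapter: plain float buffers (forward pass) are handed to f as tensors
    # under no_grad; object arrays of Duals (backward pass) are passed through unchanged,
    # since torch has no Dual dtype.
    def wrapped(v):
        if isinstance(v, np.ndarray) and v.dtype != np.object_:
            with torch.no_grad():
                return f(torch.as_tensor(v))
        return f(v)
    return wrapped
```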

TODO

It's also not clear to me whether these modifications are compatible with jit compilation (haven't dug into that yet), or whether they are still required when you jit compile.

azane commented 10 months ago

Re: note at end of issue above, I doubt these are required when jit compiling the dynamics — see my note at the end of #15 and in the comments for why working hard to solve this may not play to the strengths of using diffeqpy as a backend for chirho.

**Turns out, these modifications (or the one discussed below) are required when jit compiling.**

azane commented 9 months ago

At least where chirho is concerned, I've implemented a robust (so far 🤷 ), if clunky, workaround for when different julia types are passed through a user-defined python function. TLDR is to wrap the julia "thing" in a python object that 1) obscures the underlying julia thing from numpy, preventing introspection and unpacking to >32 dims (which is the numpy limit) and 2) forwards all the relevant math dunder methods to the underlying julia entity. Arrays of JuliaThingWrappers will have dtype=np.object_, but because the math dunder methods are all forwarded, standard vectorized math works.

My preference was to find some numpy-internal way to disable the introspection of the julia types (and just treat them as objects), but after a few hours of research I only found reports of others who were likewise unable to find that ability within numpy itself. Perhaps there's an easier way to prevent this introspection from the julia side?

This approach won't be fast, because vectorized math on arrays of objects can't fall back to numpy's c/c++ loops; it dispatches to the elements' python-level methods instead. It does work, however, and lets the user ignore the nuances of the underlying julia type.
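For a sense of that cost, the slowdown is easy to see with plain numpy (illustrative only, no julia involved):

```python
import timeit

import numpy as np

a_float = np.random.rand(100_000)
a_object = a_float.astype(object)  # same values, but each element is a boxed python float

# The float64 version runs in a single C loop; the object version calls __mul__ per element.
print(timeit.timeit(lambda: a_float * 2.0, number=100))
print(timeit.timeit(lambda: a_object * 2.0, number=100))
```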

Though I don't show this in the example below, this wrapping strategy solves the same problem when jit compiling the dynamics with diffeqpy (problem being the unpacking of julia symbolics to >32 ndim arrays). The test linked here lists a variety of dynamics functions that jit-compile and pass gradcheck when the underlying julia symbolics are wrapped in a JuliaThingWrapper before being exposed to the user-supplied dynamics function. Note that the pre- and post-processing machinery is happening internally and not shown in the test.

For chirho's current purposes, the overhead of wrapping arrays and the slow speed of vector math on ndarrays of objects only appears during jit compilation, so we aren't too concerned with optimizing it right now.

Dunder methods that get forwarded:

```python
DUNDERS = [
    '__abs__', '__add__', '__bool__', '__ceil__', '__eq__', '__float__', '__floor__', '__floordiv__',
    '__ge__', '__gt__', '__invert__', '__le__', '__lshift__', '__lt__', '__mod__', '__mul__',
    '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__', '__reversed__',
    '__rfloordiv__', '__rlshift__', '__rmod__', '__rmul__', '__ror__', '__round__', '__rpow__',
    '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__', '__rxor__', '__sub__', '__truediv__',
    '__trunc__', '__xor__',
]
```
`JuliaThingWrapper`:

```python
class JuliaThingWrapper:
    """
    This wrapper just acts as a pass-through to the julia object, but obscures the underlying memory buffer of the
    julia thing (realvalue, symbolic, dual number, etc.). This prevents numpy from introspecting the julia thing as a
    sequence with a large number of dimensions (exceeding the ndim 32 limit). Unfortunately, even with a dtype of
    np.object_, this introspection still occurs.

    The issue of casting to a numpy array can also be addressed by first creating an empty array of dtype object,
    and then filling it with the julia thing (as occurs in unwrap_array below), but this fails to generalize well in
    cases where numpy is doing the casting itself. As of now, this seems the most robust solution.

    Note that numpy arrays of objects will, internally, use the dunder math methods of the objects they contain when
    performing math operations. This is not fast, but for our purposes is fine b/c the main application here involves
    julia symbolics only during jit compilation. As such, the point of this class is to wrap scalar valued julia
    things only so that we can use numpy arrays of julia things.
    """

    def __init__(self, julia_thing):
        self.julia_thing = julia_thing

    @staticmethod
    def wrap_array(arr: np.ndarray):
        return np.vectorize(JuliaThingWrapper)(arr)

    @staticmethod
    def unwrap_array(arr: np.ndarray, out: Optional[np.ndarray] = None):
        # As discussed in docstring, we cannot simply vectorize a deconstructor because numpy will try to internally
        # cast the unwrapped_julia things into an array, which fails due to introspection triggering the ndim 32 limit.
        # Instead, we have to manually assign each element of the array. This is slow, but only occurs during jit
        # compilation for our use case.
        if out is None:
            out = np.empty(arr.shape, dtype=np.object_)
        for idx, v in np.ndenumerate(arr):
            out[idx] = v.julia_thing
        return out

    @classmethod
    def _forward_dunders(cls):
        # Forward all the math related dunder methods to the underlying julia thing.
        for method_name in DUNDERS:
            cls._make_dunder(method_name)

    def __repr__(self):
        return f"JuliaThingWrapper({self.julia_thing})"

    @classmethod
    def _make_dunder(cls, method_name):
        """
        Automate the definition of dunder methods involving the underlying julia things. Note that just intercepting
        getattr doesn't work here because dunder method calls skip getattr, and getattribute is fairly complex to
        work with.
        """

        def dunder(self: JuliaThingWrapper, *args):
            # Retrieve the underlying dunder method of the julia thing.
            method = getattr(self.julia_thing, method_name)

            if not args:
                # E.g. __neg__, __pos__, __abs__ don't have an "other"
                result = method()
                if result is NotImplemented:
                    raise NotImplementedError(f"Operation {method_name} is not implemented for {self.julia_thing}.")
            else:
                if len(args) != 1:
                    raise ValueError("Only one argument is supported for automated dunder method dispatch.")
                other, = args

                if isinstance(other, np.ndarray):
                    if other.ndim == 0:
                        # In certain cases, that TODO need to be sussed out (maybe numpy internal nuance) the
                        # julia_thing is a scalar array of a JuliaThingWrapper, so we need to further unwrap the
                        # scalar array to get at the JuliaThingWrapper (and, in turn, the julia_thing).
                        other = other.item()
                    else:
                        # Wrap self in an array and recurse back through numpy broadcasting. This is required when a
                        # JuliaThingWrapper "scalar" is involved in an operation with a numpy array on the right.
                        scalar_array_self = np.array(self)
                        scalar_array_self_attr = getattr(scalar_array_self, method_name)
                        return scalar_array_self_attr(other)

                # Extract the underlying julia thing.
                if isinstance(other, JuliaThingWrapper):
                    other = other.julia_thing

                # Perform the operation using the corresponding method of the Julia object
                result = method(other)
                if result is NotImplemented:
                    raise NotImplementedError(f"Operation {method_name} is not implemented for"
                                              f" {self.julia_thing} and {other}.")

            # Rewrap the return.
            return JuliaThingWrapper(result)

        setattr(cls, method_name, dunder)


# noinspection PyProtectedMember
JuliaThingWrapper._forward_dunders()
```

So, following the examples cited above, we can:

m = torch.tensor([1., 1.], requires_grad=True)
x = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

def f(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()

    # Wrap so that standard numpy ops work.
    v = JuliaThingWrapper.wrap_array(v)

    m = v[:2]
    x = v[2]
    b = v[3]

    ret = (m * x + b).sum()
    # Unwrap what is now a wrapped scalar. If this were an array, the in-place unwrap_array would be used.
    ret = ret.julia_thing

    print("JuliaThing", type(ret), ret)

    return ret

y = JuliaFunction.apply(f, torch.cat(torch.atleast_1d(m.ravel(), x, b)))
y.backward()

print(m.grad, x.grad, b.grad)

Which gives:

JuliaThing <class 'float'> 10.0
JuliaThing <class 'juliacall.RealValue'> Dual{ForwardDiff.Tag{var"#3#4"{Py, PyArray{Float64, 0, true, true, Float64}}, Float32}}(10.0,2.0,2.0,2.0,2.0)
tensor([2., 2.]) tensor(2.) tensor(2.)
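(As an aside, if f returned an array rather than a scalar, the unwrapping mentioned in the code comment above would look roughly like the hypothetical sketch below; this isn't something JuliaFunction is necessarily guaranteed to support as-is.)

```python
# Hypothetical array-valued variant of f above: the result stays wrapped until the very
# end, then unwrap_array produces a plain object-dtype array of bare julia things.
def f_array(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()
    v = JuliaThingWrapper.wrap_array(v)
    m, x, b = v[:2], v[2], v[3]
    wrapped = m * x + b                                 # still an array of JuliaThingWrappers
    return JuliaThingWrapper.unwrap_array(wrapped)      # or pass out=... to fill a buffer in place
```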

My thinking is that something like this wrapping machinery could be applied to the dual number array values before they are handed back to the user. The same wrapping setup could also wrap julia symbolics during jit compilation in diffeqpy.

As mentioned, this is working for us for now, so no rush here, but I wanted to flag this as a possible approach to addressing the type transparency issues.

LilithHafner commented 9 months ago

Thank you!

azane commented 9 months ago

On further investigation/testing, the above works for dunder methods but fails for numpy's universal functions. I've addressed this by splitting the functionality into dunder forwarding and ufunc forwarding, where elementwise ufuncs are forwarded to the julia function of the same name, both for a "scalar" JuliaThingWrapper and for every element of an array of "julia things". The latter requires a slightly different mechanism because, unlike with dunder methods, numpy doesn't utilize the __array_ufunc__ of the elements in an object array.
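A small demonstration of the numpy behavior that motivates the array-level interception (illustrative only, no julia involved): for object-dtype arrays, numpy looks for a method with the ufunc's name on each element rather than consulting the element's __array_ufunc__.

```python
import numpy as np


class Plain:
    """Stand-in element with no exp() method and no numpy hooks."""


array_of_objects = np.array([Plain()], dtype=object)

try:
    np.exp(array_of_objects)
except TypeError as e:
    # Something like: "loop of ufunc does not support ... which has no callable exp method"
    print(e)
```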

Original Dunder Method Forwarding:

```python
class _DunderedJuliaThingWrapper:

    def __init__(self, julia_thing):
        self.julia_thing = julia_thing

    @classmethod
    def _forward_dunders(cls):
        # Forward all the math related dunder methods to the underlying julia thing.
        dunders = [
            '__abs__', '__add__', '__bool__', '__ceil__', '__eq__', '__float__', '__floor__', '__floordiv__',
            '__ge__', '__gt__', '__invert__', '__le__', '__lshift__', '__lt__', '__mod__', '__mul__',
            '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__', '__reversed__',
            '__rfloordiv__', '__rlshift__', '__rmod__', '__rmul__', '__ror__', '__round__', '__rpow__',
            '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__', '__rxor__', '__sub__', '__truediv__',
            '__trunc__', '__xor__',
        ]
        for method_name in dunders:
            cls._make_dunder(method_name)

    @classmethod
    def _make_dunder(cls, method_name):
        """
        Automate the definition of dunder methods involving the underlying julia things. Note that just intercepting
        getattr doesn't work here because dunder method calls skip getattr, and getattribute is fairly complex to
        work with.
        """

        def dunder(self: _DunderedJuliaThingWrapper, *args):
            # Retrieve the underlying dunder method of the julia thing.
            method = getattr(self.julia_thing, method_name)

            if not args:
                # E.g. __neg__, __pos__, __abs__ don't have an "other"
                result = method()
                if result is NotImplemented:
                    raise NotImplementedError(f"Operation {method_name} is not implemented for {self.julia_thing}.")
            else:
                if len(args) != 1:
                    raise ValueError("Only one argument is supported for automated dunder method dispatch.")
                other, = args

                if isinstance(other, np.ndarray):
                    if other.ndim == 0:
                        # In certain cases, that TODO need to be sussed out (maybe numpy internal nuance) the
                        # julia_thing is a scalar array of a JuliaThingWrapper, so we need to further unwrap the
                        # scalar array to get at the JuliaThingWrapper (and, in turn, the julia_thing).
                        other = other.item()
                    else:
                        # Wrap self in an array and recurse back through numpy broadcasting. This is required when a
                        # JuliaThingWrapper "scalar" is involved in an operation with a numpy array on the right.
                        scalar_array_self = np.array(self)
                        scalar_array_self_attr = getattr(scalar_array_self, method_name)
                        return scalar_array_self_attr(other)

                # Extract the underlying julia thing.
                if isinstance(other, _DunderedJuliaThingWrapper):
                    other = other.julia_thing

                # Perform the operation using the corresponding method of the Julia object
                result = method(other)
                if result is NotImplemented:
                    raise NotImplementedError(f"Operation {method_name} is not implemented for"
                                              f" {self.julia_thing} and {other}.")

            # Rewrap the return.
            return JuliaThingWrapper(result)

        setattr(cls, method_name, dunder)


# noinspection PyProtectedMember
_DunderedJuliaThingWrapper._forward_dunders()
```
Exposed JuliaThingWrapper, now forwarding ufuncs:

```python
class JuliaThingWrapper(_DunderedJuliaThingWrapper):
    """
    This wrapper just acts as a pass-through to the julia object, but obscures the underlying memory buffer of the
    julia thing (realvalue, symbolic, dual number, etc.). This prevents numpy from introspecting the julia thing as a
    sequence with a large number of dimensions (exceeding the ndim 32 limit). Unfortunately, even with a dtype of
    np.object_, this introspection still occurs.

    The issue of casting to a numpy array can also be addressed by first creating an empty array of dtype object,
    and then filling it with the julia thing (as occurs in unwrap_array below), but this fails to generalize well in
    cases where numpy is doing the casting itself. As of now, this seems the most robust solution.

    Note that numpy arrays of objects will, internally, use the dunder math methods of the objects they contain when
    performing math operations. This is not fast, but for our purposes is fine b/c the main application here involves
    julia symbolics only during jit compilation. As such, the point of this class is to wrap scalar valued julia
    things only so that we can use numpy arrays of julia things.

    This class also handles the forwarding of numpy universal functions like sin, exp, log, etc. to the
    corresponding julia version. See __array_ufunc__ for more details.
    """

    @staticmethod
    def wrap_array(arr: np.ndarray):
        regular_array = np.vectorize(JuliaThingWrapper)(arr)
        # Because we need to forward numpy ufuncs to julia, return a view that intercepts ufuncs at the array level.
        return regular_array.view(_JuliaThingWrapperArray)

    @staticmethod
    def unwrap_array(arr: np.ndarray, out: Optional[np.ndarray] = None):
        # As discussed in docstring, we cannot simply vectorize a deconstructor because numpy will try to internally
        # cast the unwrapped julia things into an array, which fails due to introspection triggering the ndim 32 limit.
        # Instead, we have to manually assign each element of the array. This is slow, but only occurs during jit
        # compilation for our use case.
        if out is None:
            out = np.empty(arr.shape, dtype=np.object_)
        for idx, v in np.ndenumerate(arr):
            out[idx] = v.julia_thing
        return out

    def __repr__(self):
        return f"JuliaThingWrapper({self.julia_thing})"

    def _jl_ufunc(self, ufunc):
        # Try to grab the julia function with the same name as the ufunc.
        ufunc_name = ufunc.__name__
        try:
            jlfunc = getattr(jl, ufunc_name)
        except AttributeError:
            # when __array_ufunc__ fails to resolve, it returns NotImplemented, so this follows that pattern.
            return NotImplemented

        result = jlfunc(self.julia_thing)
        return JuliaThingWrapper(result)

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        """
        Many numpy functions (like sin, exp, log, etc.) are so-called "universal functions" that don't correspond to
        a standard dunder method. To handle these, we need to dispatch the julia version of the function on the
        underlying julia thing. This is done by overriding __array_ufunc__ and forwarding the call to the jl function
        operating on the underlying julia thing, assuming that the corresponding dunder method hasn't already been
        defined.
        """
        # First try to evaluate the default ufunc (this will dispatch first to dunder methods, like __abs__).
        ret = _default_ufunc(ufunc, method, *args, **kwargs)
        if ret is not NotImplemented:
            return ret

        # Otherwise, try to dispatch the ufunc to the underlying julia thing.
        return self._jl_ufunc(ufunc)
```
_JuliaThingWrapperArray, handling ufunc forwarding at array level:

```python
class _JuliaThingWrapperArray(np.ndarray):
    """
    Subclassing the numpy array in order to translate ufunc calls to julia equivalent calls at the array level
    (rather than the element level). This is required because numpy doesn't defer to the __array_ufunc__ method of
    the underlying object for arrays of dtype object.
    """

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        # First, try to resolve the ufunc in the standard manner (this will dispatch first to dunder methods).
        ret = _default_ufunc(ufunc, method, *args, **kwargs)
        if ret is not NotImplemented:
            return ret

        # Otherwise, because numpy doesn't defer to the ufuncs of the underlying objects when being called on an
        # array of objects (unlike how it behaves with dunder-definable ops), iterate manually and do so here.
        result = _JuliaThingWrapperArray(self.shape, dtype=object)
        for idx, v in np.ndenumerate(self):
            assert isinstance(v, JuliaThingWrapper)
            result[idx] = v._jl_ufunc(ufunc)
        return result
```
Machinery for switching back and forth between default and intercepted ufuncs:

```python
def _default_ufunc(ufunc, method, *args, **kwargs):
    f = getattr(ufunc, method)

    # Numpy's behavior changes if __array_ufunc__ is defined at all, i.e. a super() call is insufficient to
    # capture the default behavior as if no __array_ufunc__ were involved. The way to do this is to create
    # a standard np.ndarray view of the underlying memory, and then call the ufunc on that.
    nargs = (cast_as_lacking_array_ufunc(x) for x in args)

    try:
        ret = f(*nargs, **kwargs)
        return cast_as_having_array_func(ret)
    except TypeError as e:
        # If the exception reports anything besides non-implementation of the ufunc, then re-raise.
        if f"no callable {ufunc.__name__} method" not in str(e):
            raise
        # Otherwise, just return NotImplemented in keeping with standard __array_ufunc__ behavior.
        else:
            return NotImplemented


# These functions handle casting back and forth from entities that have custom behavior for numpy ufuncs, and those
# that don't.
@singledispatch
def cast_as_lacking_array_ufunc(v):
    return v


@cast_as_lacking_array_ufunc.register
def _(v: _JuliaThingWrapperArray):
    return v.view(np.ndarray)


@cast_as_lacking_array_ufunc.register
def _(v: JuliaThingWrapper):
    return _DunderedJuliaThingWrapper(v.julia_thing)


@singledispatch
def cast_as_having_array_func(v):
    return v


@cast_as_having_array_func.register
def _(v: np.ndarray):
    return v.view(_JuliaThingWrapperArray)


@cast_as_having_array_func.register
def _(v: _DunderedJuliaThingWrapper):
    return JuliaThingWrapper(v.julia_thing)
```

With all that in place, we can:

def f(v):
    if isinstance(v, juliacall.ArrayValue):
        v = v.to_numpy()

    # Wrap so that standard numpy ops work.
    v = JuliaThingWrapper.wrap_array(v)

    # Using ufuncs on arrays.
    print("Array", type(v))
    v = np.exp(np.log(v))

    m = v[:2]
    x = v[2]
    b = v[3]

    # Using ufuncs on an individual element.
    print("Element", type(b))
    b = (b / np.abs(b)) * np.abs(b)

    wrapped_ret = (m * x + b).sum()

    # Unwrap what is now a wrapped scalar. If this were an array, the in-place unwrap_array would be used.
    ret = wrapped_ret.julia_thing

    print("JuliaThing", type(ret), ret)

    return ret

m = torch.tensor([1., 1.])
x = torch.tensor(2.)
b = torch.tensor(3.)

mxb_cat = torch.cat(torch.atleast_1d(m, x, b)).double().requires_grad_()

print("Forward")
y = JuliaFunction.apply(f, mxb_cat)

print("\nBackward")
y.backward()

print("\nGrad")
print(mxb_cat.grad)

# torch.autograd.gradcheck(JuliaFunction.apply, (f, mxb_cat))

Which gives:

Forward
Array <class '__main__._JuliaThingWrapperArray'>
Element <class '__main__.JuliaThingWrapper'>
JuliaThing <class 'float'> 10.0

Backward
Array <class '__main__._JuliaThingWrapperArray'>
Element <class '__main__.JuliaThingWrapper'>
JuliaThing <class 'juliacall.RealValue'> Dual{ForwardDiff.Tag{var"#3#4"{Py, PyArray{Float64, 0, true, true, Float64}}, Float64}}(10.0,2.0,2.0,2.0,2.0)

Grad
tensor([2., 2., 2., 2.], dtype=torch.float64)

azane commented 9 months ago

Also note that I haven't exhaustively tested all ufuncs and all of the forwarded dunder methods. There are bound to be edge cases here where the automated dispatching won't work out of the box.

LilithHafner commented 9 months ago

Possibly related: https://github.com/JuliaPy/PythonCall.jl/issues/390