datasig-ac-uk / RoughPy

Toolbox for working with streaming data as rough paths in Python
https://roughpy.org
BSD 3-Clause "New" or "Revised" License

Jax numpy arrays and NaNs in the log-signature. #25

Closed: Benjamin-Walker closed this issue 1 year ago

Benjamin-Walker commented 1 year ago

The log-signature of a path depends significantly on whether the input array started out as a jax.numpy array that was then converted to a NumPy array, or was a NumPy array from the start. In the converted case, NaNs can appear in the log-signature. The behaviour also seems random: the same script behaves differently on subsequent runs, and the NaNs are sometimes absent. However, I have never observed agreement between the three examples below, all of which I expected to return (roughly) the same output.

I believe this may actually be an issue with how Jax handles numpy conversion. I apologise if this turns out to be a Jax problem instead of a roughpy problem.

The following script demonstrates this behaviour.

import jax.numpy as jnp
import numpy as np
import roughpy as rp

X = np.array([[-0.25860816, -0.36977386, 0.6619457, -0.50442713, 0.08028925, -1.06701028],
              [-0.26208243, 0.22464547, 0.39521545, -0.62663144, -0.34344956, -1.67293704],
              [-0.55824, -0.19376263, 0.86616075, -0.58314389, -0.69254208, -1.53291035],
              [-0.52306908, -0.09234464, 1.17564034, -0.7388621, -0.91333717, -1.50844121],
              [-0.80696738, -0.09417236, 0.75135314, -1.20548987, -1.42038512, -1.86834741],
              [-0.6642682, -0.12166289, 1.04914618, -1.01415539, -1.58841276, -2.54356289]])

Y = jnp.array([[-0.25860816, -0.36977386, 0.6619457, -0.50442713, 0.08028925, -1.06701028],
               [-0.26208243, 0.22464547, 0.39521545, -0.62663144, -0.34344956, -1.67293704],
               [-0.55824, -0.19376263, 0.86616075, -0.58314389, -0.69254208, -1.53291035],
               [-0.52306908, -0.09234464, 1.17564034, -0.7388621, -0.91333717, -1.50844121],
               [-0.80696738, -0.09417236, 0.75135314, -1.20548987, -1.42038512, -1.86834741],
               [-0.6642682, -0.12166289, 1.04914618, -1.01415539, -1.58841276, -2.54356289]])

stream_np = rp.LieIncrementStream.from_increments(X, width=6, depth=2, dtype=rp.SPReal)
print(stream_np.log_signature(rp.RealInterval(0, 2)))  # { -0.520691(1) -0.145128(2) 1.05716(3) ...}
print(stream_np.log_signature(rp.RealInterval(4, 6)))  #  { -1.47124(1) -0.215835(2) 1.8005(3) ...}

stream_jnp = rp.LieIncrementStream.from_increments(np.array(Y), width=6, depth=2, dtype=rp.SPReal)
print(stream_jnp.log_signature(rp.RealInterval(0, 2)))  # { -1.7583e-06(1) -0.000113874(2) -0.169911(3) ...}
print(stream_jnp.log_signature(rp.RealInterval(4, 6)))  # { -nan(3) -nan(5) -nan([1,3]) -nan([1,5]) -nan([2,3]) ...}

Y = np.array(Y)
stream_jnp = rp.LieIncrementStream.from_increments(Y, width=6, depth=2, dtype=rp.SPReal)
print(stream_jnp.log_signature(rp.RealInterval(0, 2)))  # { -1.7583e-06(1) -0.000113874(2) -0.169911(3) ...}
print(stream_jnp.log_signature(rp.RealInterval(4, 6)))  # { 2(1) 3(3) 5(5) -1([1,5]) -1.5([3,5]) } 
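
One property worth checking in the script above: JAX creates 32-bit floats by default (unless the `jax_enable_x64` flag is set), so `np.array(Y)` should be a float32 array, while `X`, built by NumPy from Python floats, is float64. The sketch below simulates the JAX-derived array with `astype(np.float32)` so it runs without JAX, and shows what happens if a float32 buffer is read as though it held doubles. That this is the failure mode is an assumption here, though it is consistent with the later diagnosis in this thread that the dtype lookup always returned DPReal.

```python
import numpy as np

# First two columns of the first two rows of X from the script above.
X = np.array([[-0.25860816, -0.36977386],
              [-0.26208243, 0.22464547]])

# Stand-in for np.array(Y): JAX's default float32 data, simulated without JAX.
Y32 = X.astype(np.float32)

print(X.dtype, X.nbytes)      # float64 32
print(Y32.dtype, Y32.nbytes)  # float32 16

# Reinterpret the float32 bytes as float64, as a reader that assumes DPReal would.
misread = Y32.view(np.float64)
print(misread.shape)             # (2, 1): half the expected number of values
print(misread[0, 0] == X[0, 0])  # False: the bits encode an unrelated number
```

A reader that walks the expected grid of doubles over a buffer half that size would also run past the end of the allocation, which is one plausible explanation for the run-to-run variation and the NaNs reported above.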
Benjamin-Walker commented 1 year ago

Having reinstalled Jax and Jaxlib, the script in the above comment produces slightly different behaviour:

stream_jnp = rp.LieIncrementStream.from_increments(np.array(Y), width=6, depth=2, dtype=rp.SPReal)
print(stream_jnp.log_signature(rp.RealInterval(0, 2)))  # { -1.7583e-06(1) -0.000113874(2) -0.169911(3) ...}
print(stream_jnp.log_signature(rp.RealInterval(4, 6)))  # { 9.66244e+16(3) -nan(4) -nan(5) -nan([1,4]) ...}

Y = np.array(Y)
stream_jnp_convertb4 = rp.LieIncrementStream.from_increments(Y, width=6, depth=2, dtype=rp.SPReal)
print(stream_jnp_convertb4.log_signature(rp.RealInterval(0, 2)))  # { -1.7583e-06(1) -0.000113874(2) -0.169911(3) ...}
print(stream_jnp_convertb4.log_signature(rp.RealInterval(4, 6)))  # { 9.66244e+16(3) -nan(4) -nan(5) -nan([1,4]) ...}
terrylyons commented 1 year ago

Is someone making assumptions about the internal representations of numpy arrays? Years ago, I came unstuck over data contiguity. Perhaps current versions have changed their assumptions yet again? Just guessing.

(Sent by email in reply to Ben Walker's comment of Tuesday, August 1, 2023.)
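The contiguity question can be checked directly from NumPy's array metadata. The helper below (`layout_summary` is a hypothetical name, not part of RoughPy or NumPy) reports the properties a native consumer of the buffer would rely on. Note that changing the dtype alters the item size and strides even for a perfectly contiguous array.

```python
import numpy as np

def layout_summary(a: np.ndarray) -> dict:
    """Collect the memory-layout facts a C/C++ consumer of the buffer relies on."""
    return {
        "dtype": a.dtype.name,
        "itemsize": a.itemsize,
        "c_contiguous": bool(a.flags["C_CONTIGUOUS"]),
        "strides": a.strides,
    }

base = np.arange(36, dtype=np.float64).reshape(6, 6)

print(layout_summary(base))                     # contiguous float64, strides (48, 8)
print(layout_summary(base.T))                   # transposed view: strides (8, 48), not C-contiguous
print(layout_summary(base.astype(np.float32)))  # contiguous, but itemsize 4, strides (24, 4)

# Defensive conversion: a C-contiguous float64 copy of the transposed view.
safe = np.ascontiguousarray(base.T, dtype=np.float64)
print(layout_summary(safe)["c_contiguous"])     # True
```

In the reported script, `np.array(Y)` produces a fresh copy and should already be C-contiguous, so a mismatch in element size looks like a likelier culprit than contiguity.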

inakleinbottle commented 1 year ago

Ok, so I think the problem comes from my implementation of the dlpack protocol. Still looking into it.
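For reference, a DLPack consumer is expected to take the element type from the dtype descriptor carried with the tensor rather than assume one. NumPy's own consumer, `np.from_dlpack` (NumPy 1.22 and later), shows the expected behaviour: a float32 producer yields a float32 result. A minimal sketch using only NumPy, not RoughPy:

```python
import numpy as np

src32 = np.linspace(-1.0, 1.0, 6, dtype=np.float32)
src64 = src32.astype(np.float64)

# np.from_dlpack reads the dtype carried by the DLPack capsule,
# so each result keeps its producer's element type.
out32 = np.from_dlpack(src32)
out64 = np.from_dlpack(src64)

print(out32.dtype, out64.dtype)  # float32 float64
```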

inakleinbottle commented 1 year ago

OK, so I've done a couple of things but I don't think I've solved the problem yet.

inakleinbottle commented 1 year ago

Ok so I think I've found the cause of the problem. The function that gets the data type for a dlpack tensor always returns DPReal (as a default implementation). Basically I need to implement this function properly. I'm working on this now.
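The fix described here amounts to deriving the scalar type from the tensor's DLPack dtype (a type code plus a bit width) instead of returning a fixed default. RoughPy's actual change is in PR #29 and is not reproduced here. The following is a hypothetical Python sketch of such a mapping. It uses the type codes from DLPack's `dlpack.h` (`kDLInt = 0`, `kDLUInt = 1`, `kDLFloat = 2`), and it returns RoughPy scalar type names as plain strings. The function and class names are illustrative only.

```python
from typing import NamedTuple

# DLPack type codes (DLDataTypeCode in dlpack.h); only those used here.
KDL_INT = 0
KDL_UINT = 1
KDL_FLOAT = 2

class DLDataType(NamedTuple):
    code: int   # one of the DLDataTypeCode values
    bits: int   # bits per lane, e.g. 32 or 64
    lanes: int  # 1 for ordinary (non-vectorised) scalars

def scalar_type_from_dlpack(dt: DLDataType) -> str:
    """Hypothetical mapping from a DLPack dtype to a RoughPy scalar type name.

    Unlike a default that always answers "DPReal", this inspects the
    code/bits pair and rejects anything it does not recognise.
    """
    if dt.lanes != 1:
        raise TypeError(f"vectorised dtypes are not supported: {dt}")
    if dt.code == KDL_FLOAT and dt.bits == 32:
        return "SPReal"
    if dt.code == KDL_FLOAT and dt.bits == 64:
        return "DPReal"
    raise TypeError(f"unsupported DLPack dtype: {dt}")

print(scalar_type_from_dlpack(DLDataType(KDL_FLOAT, 32, 1)))  # SPReal
```

Raising on unknown types, rather than falling back to a default, turns a silent reinterpretation of the buffer into an explicit error.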

inakleinbottle commented 1 year ago

Fixed this with PR #29