Closed Benjamin-Walker closed 1 year ago
Having reinstalled Jax and Jaxlib, the script in the above comment produces slightly different behaviour:
stream_jnp = rp.LieIncrementStream.from_increments(np.array(Y), width=6, depth=2, dtype=rp.SPReal)
print(stream_jnp.log_signature(rp.RealInterval(0, 2))) # { -1.7583e-06(1) -0.000113874(2) -0.169911(3) ...}
print(stream_jnp.log_signature(rp.RealInterval(4, 6))) # { 9.66244e+16(3) -nan(4) -nan(5) -nan([1,4]) ...}
Y = np.array(Y)
stream_jnp_convertb4 = rp.LieIncrementStream.from_increments(Y, width=6, depth=2, dtype=rp.SPReal)
print(stream_jnp_convertb4.log_signature(rp.RealInterval(0, 2))) # { -1.7583e-06(1) -0.000113874(2) -0.169911(3) ...}
print(stream_jnp_convertb4.log_signature(rp.RealInterval(4, 6))) # { 9.66244e+16(3) -nan(4) -nan(5) -nan([1,4]) ...}
Is someone making assumptions about the internal representations of numpy arrays? Years ago, I came unstuck over data contiguity. Perhaps current versions have changed their assumptions yet again? Just guessing.
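For what it's worth, the contiguity guess above is easy to check from the Python side. The snippet below is only a sketch using plain NumPy (nothing RoughPy-specific, and the array `a` is made up for illustration): it inspects the layout flags and strides of an array and forces a C-contiguous copy before the data is handed to an extension.

```python
import numpy as np

# Made-up example data: a 3x4 float32 array and its transpose.
a = np.arange(12, dtype=np.float32).reshape(3, 4)
t = a.T  # a view on the same buffer, but with swapped strides

print(a.flags["C_CONTIGUOUS"], a.strides)  # True (16, 4)
print(t.flags["C_CONTIGUOUS"], t.strides)  # False (4, 16)

# Forcing a C-contiguous copy removes any dependence on stride handling
# in whatever consumes the raw buffer.
safe = np.ascontiguousarray(t)
print(safe.flags["C_CONTIGUOUS"], safe.strides)  # True (12, 4)
```

If the NaNs persist after passing a contiguous copy, the layout is probably not the culprit and the element type is the next thing to check.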
Ok, so I think the problem comes from my implementation of the dlpack protocol. Still looking into it.
OK, so I've done a couple of things but I don't think I've solved the problem yet.
Ok so I think I've found the cause of the problem. The function that gets the data type for a dlpack tensor always returns DPReal (as a default implementation). Basically I need to implement this function properly. I'm working on this now.
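To illustrate why a hard-coded double-precision default would produce this kind of garbage, here is a rough sketch in plain NumPy. It is not RoughPy's actual code: `scalar_type_from_dlpack` is a hypothetical helper, and the strings it returns merely stand in for `rp.SPReal` and `rp.DPReal`. The only parts taken from the DLPack header are the `code`, `bits` and `lanes` fields of a tensor's dtype and the value 2 for floating point (`kDLFloat`). JAX uses 32-bit floats unless 64-bit mode is enabled, so an array converted from JAX is typically `float32`.

```python
import numpy as np

# DLPack type code for floating point (kDLFloat in dlpack.h).
K_DL_FLOAT = 2


def scalar_type_from_dlpack(code: int, bits: int, lanes: int = 1) -> str:
    """Hypothetical helper: choose a scalar type from a DLPack dtype triple
    instead of always assuming double precision."""
    if lanes != 1:
        raise TypeError("vectorised lanes are not handled in this sketch")
    if code == K_DL_FLOAT and bits == 32:
        return "SPReal"  # stand-in for rp.SPReal
    if code == K_DL_FLOAT and bits == 64:
        return "DPReal"  # stand-in for rp.DPReal
    raise TypeError(f"unsupported DLPack dtype: code={code}, bits={bits}")


# Made-up float32 data, as would typically come out of a JAX conversion.
x32 = np.array([0.1, -0.25, 3.0, 1e-3], dtype=np.float32)

# Reading the same bytes as float64 is what an always-DPReal default
# effectively does: four floats become two unrelated doubles.
misread = x32.view(np.float64)

print(x32)      # the four intended values
print(misread)  # two values unrelated to the input

# Resolving the type from the actual element width gives single precision.
print(scalar_type_from_dlpack(K_DL_FLOAT, x32.dtype.itemsize * 8))  # SPReal
```

A misread buffer is also only half as long in elements as the reader expects, so a consumer that trusts the original shape would run past the end of the data. That would be consistent with the run-to-run variation reported in this issue, though this is an inference rather than something confirmed here.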
Fixed this with PR #29
The log-signature of a path depends significantly on whether the input array started out as a JAX NumPy array that was then converted to a NumPy array, or was a NumPy array from the start. This can lead to NaNs appearing in the log-signature. There is also an element of randomness in this behaviour: the same script behaves differently on subsequent runs, and the NaNs are sometimes absent. However, I have never observed agreement between the three examples below, which I expected to all return (roughly) the same output.
I believe this may actually be an issue with how JAX handles NumPy conversion. I apologise if this turns out to be a JAX problem rather than a RoughPy problem.
The following script demonstrates this behaviour.