normal-computing / thermox

Exact OU processes with JAX
Apache License 2.0

sample function not handling Processed input matrices #22

Closed: denismelanson closed this issue 4 months ago

denismelanson commented 4 months ago

The docstring for "thermox.sampler.sample" states that the input matrices A and D can each be either a JAX Array or a processed instance (ProcessedDriftMatrix for A, ProcessedDiffusionMatrix for D). However, the function fails when processed instances are passed in.

Here is the code to reproduce the error:

import thermox
from thermox.utils import preprocess
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

dt = 0.01
ts = jnp.arange(0, 1, dt)

A = jnp.array([[2.0, 0.5, 0.0, 0.0, 0.0],
               [0.5, 2.0, 0.5, 0.0, 0.0],
               [0.0, 0.5, 2.0, 0.5, 0.0],
               [0.0, 0.0, 0.5, 2.0, 0.5],
               [0.0, 0.0, 0.0, 0.5, 2.0]])

b, x0 = jnp.zeros(5), jnp.zeros(5) # Zero drift displacement vector and initial state

D = jnp.array([[2, 1, 0, 0, 0],
               [1, 2, 0, 0, 0],
               [0, 0, 2, 0, 0],
               [0, 0, 0, 2, 0],
               [0, 0, 0, 0, 2]])

A_star, D_star = preprocess(A, D)

samples = thermox.sample(key, ts, x0, A_star, b, D_star)

I get the following traceback:

Traceback (most recent call last):
  File "/workspaces/thermox/testing.py", line 36, in <module>
    samples = thermox.sample(key, ts, x0, A_star, b, D_star)
  File "/workspaces/thermox/thermox/sampler.py", line 109, in sample
    assert isinstance(A_y, ProcessedDriftMatrix)
UnboundLocalError: local variable 'A_y' referenced before assignment

This does not seem like the intended behavior.

denismelanson commented 4 months ago

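Looking at the source of the sample function, this case is not handled: preprocess is only called when both A and D are raw Arrays, so when either input is already processed, A_y is never assigned before the assert reads it.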

if isinstance(A, Array) and isinstance(D, Array):
    A_y, D = preprocess(A, D)

assert isinstance(A_y, ProcessedDriftMatrix)
assert isinstance(D, ProcessedDiffusionMatrix)
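This also explains why the failure shows up as an UnboundLocalError rather than an AssertionError. Because A_y is assigned somewhere in the function body, Python treats it as a local variable throughout the function. When the if branch is skipped, the name is never bound, so simply reading it inside the assert raises before the isinstance check can run. Below is a minimal stand-alone sketch of the same control flow; sample_like is a hypothetical name, not part of thermox:

```python
def sample_like(A, D, both_raw):
    """Hypothetical stand-in mirroring the control flow of sample()."""
    # A_y is only assigned inside this branch, yet Python still treats it
    # as a local variable for the whole function body.
    if both_raw:
        A_y = ("processed", A)
    # If the branch was skipped, A_y is unbound here, so reading it raises
    # UnboundLocalError before the assertion condition is evaluated.
    assert A_y[0] == "processed"
    return A_y, D


print(sample_like(1.0, 2.0, both_raw=True))  # works: A_y was bound
try:
    sample_like(1.0, 2.0, both_raw=False)
except UnboundLocalError as err:
    print(type(err).__name__, err)
```

The exact message depends on the Python version. Older versions print "referenced before assignment", as in the traceback above, while 3.11+ words it differently. The exception type is the same in both cases.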

I suggest changing that check in sample to match how the log_prob function handles this case:

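if isinstance(A, Array) or isinstance(D, Array):
    # Unwrap any already-processed input back to its raw matrix
    if isinstance(A, ProcessedDriftMatrix):
        A = A.val
    if isinstance(D, ProcessedDiffusionMatrix):
        D = D.val
    A_y, D = preprocess(A, D)
else:
    # Both inputs are already processed: use them as-is
    A_y = A

KaelanDt commented 4 months ago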

Nice find! Indeed your solution seems to work. Want to make a PR?
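The proposed fix can also be read as a small normalization step that must cover all four input combinations: raw/raw, raw/processed, processed/raw, and processed/processed. The last combination is the one from the reproduction above. The sketch below makes that explicit under some loud assumptions. ProcessedDriftMatrix, ProcessedDiffusionMatrix and preprocess are hypothetical stand-ins that model only the .val attribute used in the suggested fix, not thermox's real classes. ensure_processed is likewise a hypothetical helper name.

```python
from dataclasses import dataclass
from typing import Any, Tuple


# Hypothetical stand-ins for thermox's processed matrix classes. Only the
# `.val` attribute (the raw matrix) is modeled; the real classes hold more.
@dataclass
class ProcessedDriftMatrix:
    val: Any


@dataclass
class ProcessedDiffusionMatrix:
    val: Any


def preprocess(A, D):
    """Hypothetical stand-in for thermox.utils.preprocess: just wraps inputs."""
    return ProcessedDriftMatrix(A), ProcessedDiffusionMatrix(D)


def ensure_processed(A, D) -> Tuple[ProcessedDriftMatrix, ProcessedDiffusionMatrix]:
    """Return a processed (A_y, D) pair for any mix of raw/processed inputs."""
    # Both already processed (the case in this issue): pass them through.
    if isinstance(A, ProcessedDriftMatrix) and isinstance(D, ProcessedDiffusionMatrix):
        return A, D
    # At least one raw input: unwrap any processed one to its raw matrix and
    # run preprocess on the pair together, as in the suggested fix.
    if isinstance(A, ProcessedDriftMatrix):
        A = A.val
    if isinstance(D, ProcessedDiffusionMatrix):
        D = D.val
    return preprocess(A, D)


A_star, D_star = preprocess([[2.0]], [[1.0]])
A_y, D_y = ensure_processed(A_star, D_star)       # processed/processed
print(A_y is A_star, D_y is D_star)
A_y, D_y = ensure_processed([[2.0]], D_star)      # raw/processed
print(A_y, D_y)
```

Reusing an already-processed pair avoids calling preprocess a second time. The mixed case reprocesses from the raw matrices so that the returned drift and diffusion objects come from the same preprocess call.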