Closed by denismelanson 4 months ago.
Looking at the source of the sample function, it's clear that this case is not handled:
if isinstance(A, Array) and isinstance(D, Array):
    A_y, D = preprocess(A, D)
assert isinstance(A_y, ProcessedDriftMatrix)
assert isinstance(D, ProcessedDiffusionMatrix)
I suggest changing the code above to match how the log_prob function handles this case:
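if isinstance(A, Array) or isinstance(D, Array):
    # At least one input is raw: unwrap any processed input, then preprocess both.
    if isinstance(A, ProcessedDriftMatrix):
        A = A.val
    if isinstance(D, ProcessedDiffusionMatrix):
        D = D.val
    A_y, D = preprocess(A, D)
else:
    A_y = A  # both inputs are already processed, so bind A_y directly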
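To check the suggested branching against every input combination without installing thermox, here is a minimal, self-contained sketch. ProcessedDriftMatrix, ProcessedDiffusionMatrix, preprocess and normalize_inputs below are hypothetical stand-ins that only mimic the thermox types (plain Python lists play the role of JAX Arrays, and preprocess merely wraps its inputs). The sketch also binds A_y directly when both inputs are already processed: in that case the outer condition is false, and without a fallback A_y would be unbound again.

```python
from dataclasses import dataclass

# Hypothetical stand-in: plain Python lists play the role of JAX Arrays.
Array = list


@dataclass
class ProcessedDriftMatrix:  # hypothetical stand-in for the thermox class
    val: Array  # matrix wrapped by this processed instance


@dataclass
class ProcessedDiffusionMatrix:  # hypothetical stand-in for the thermox class
    val: Array  # matrix wrapped by this processed instance


def preprocess(A: Array, D: Array):
    """Hypothetical preprocess: only wraps the raw matrices."""
    return ProcessedDriftMatrix(A), ProcessedDiffusionMatrix(D)


def normalize_inputs(A, D):
    """Return (A_y, D) as processed matrices for any mix of raw/processed inputs."""
    if isinstance(A, Array) or isinstance(D, Array):
        # At least one input is raw: unwrap the processed one (if any),
        # then preprocess both together.
        if isinstance(A, ProcessedDriftMatrix):
            A = A.val
        if isinstance(D, ProcessedDiffusionMatrix):
            D = D.val
        A_y, D = preprocess(A, D)
    else:
        # Both inputs were already processed: use them as-is.
        A_y = A
    assert isinstance(A_y, ProcessedDriftMatrix)
    assert isinstance(D, ProcessedDiffusionMatrix)
    return A_y, D


# Usage: every raw/processed combination yields processed outputs.
raw_A, raw_D = [[1.0]], [[2.0]]
cases = [
    (raw_A, raw_D),
    (ProcessedDriftMatrix(raw_A), raw_D),
    (raw_A, ProcessedDiffusionMatrix(raw_D)),
    (ProcessedDriftMatrix(raw_A), ProcessedDiffusionMatrix(raw_D)),
]
for A, D in cases:
    A_y, D_out = normalize_inputs(A, D)
    # prints "ProcessedDriftMatrix ProcessedDiffusionMatrix" for each case
    print(type(A_y).__name__, type(D_out).__name__)
```

Unwrapping a single processed input and re-running preprocess keeps the two matrices in sync, at the cost of repeating work the caller already did for that one input.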
Nice find! Indeed your solution seems to work. Want to make a PR?
The docstring for "thermox.sampler.sample" says that the input matrices A and D can each be either a JAX Array or a ProcessedDriftMatrix/ProcessedDiffusionMatrix instance, respectively. However, the function fails when processed instances are passed as input.
Here is the code to reproduce the error:
I get the following traceback:

Traceback (most recent call last):
  File "/workspaces/thermox/testing.py", line 36, in <module>
    samples = thermox.sample(key, ts, x0, A_star, b, D_star)
  File "/workspaces/thermox/thermox/sampler.py", line 109, in sample
    assert isinstance(A_y, ProcessedDriftMatrix)
UnboundLocalError: local variable 'A_y' referenced before assignment
This does not seem like the intended behavior.
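The failure mode can be reproduced without thermox. In the minimal sketch below (the classes, preprocess and sample_inputs_buggy are hypothetical stand-ins, with plain lists playing the role of JAX Arrays), A_y is bound only when both inputs are raw arrays, so any other combination reaches the assert with A_y unbound and Python raises UnboundLocalError rather than an AssertionError.

```python
from dataclasses import dataclass

Array = list  # hypothetical stand-in for a JAX Array


@dataclass
class ProcessedDriftMatrix:  # hypothetical stand-in
    val: Array


@dataclass
class ProcessedDiffusionMatrix:  # hypothetical stand-in
    val: Array


def preprocess(A, D):
    """Hypothetical preprocess: only wraps the raw matrices."""
    return ProcessedDriftMatrix(A), ProcessedDiffusionMatrix(D)


def sample_inputs_buggy(A, D):
    # Mirrors the original check: A_y is bound only when BOTH inputs are raw.
    if isinstance(A, Array) and isinstance(D, Array):
        A_y, D = preprocess(A, D)
    assert isinstance(A_y, ProcessedDriftMatrix)  # A_y may be unbound here
    assert isinstance(D, ProcessedDiffusionMatrix)
    return A_y, D


try:
    sample_inputs_buggy(ProcessedDriftMatrix([[1.0]]), ProcessedDiffusionMatrix([[2.0]]))
except UnboundLocalError as err:
    # The exact message wording differs between Python versions.
    print(type(err).__name__, err)
```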