google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0
1.52k stars, 195 forks

Bug in BrownianInterval #25

Closed: lxuechen closed this issue 3 years ago.

lxuechen commented 3 years ago

This bug relates to my concern that there are virtually no tests for BrownianInterval at the moment. In retrospect, it shouldn't have been merged into dev without testing the aspects I mentioned in #15. Shape tests are basic, though they are helpful most of the time. The change to the solvers' API also breaks backward compatibility.

The data structure fails to produce A (the Lévy area) with the correct shape:

```python
# Reproduces the bug on the dev branch at the time of this issue (fixed in #28).
import torch

from torchsde import BrownianInterval

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

batch_size, d, m = 16, 3, 2
t0, t1 = 0.0, 0.3
dtype = torch.get_default_dtype()

# Brownian motion of shape (batch_size, m) with Foster's Lévy area approximation.
bm_general = BrownianInterval(
    t0=t0, t1=t1, shape=(batch_size, m), dtype=dtype, device=device, levy_area_approximation='foster'
)
W, U, A = bm_general(0.0, 0.1)  # Fails.
print(W.size())
print(U.size())
print(A.size())
```
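Beyond the basic shape tests mentioned above, a slightly stronger check can use a structural property: Lévy area is antisymmetric (A_ij = -A_ji), so a correct A should both have one m×m matrix per batch element and satisfy A + A^T = 0. The sketch below is hypothetical and not part of torchsde: the helper names are made up, it assumes W and U have shape (batch_size, m) and A has shape (batch_size, m, m), and it uses NumPy arrays as stand-ins for torch tensors so it runs without torchsde installed. The A it builds is a toy antisymmetric tensor, not torchsde's Foster approximation.

```python
import numpy as np

# Hypothetical test helpers (not part of torchsde).


def levy_shapes_ok(W, U, A, batch_size, m):
    """Assumed convention: W and U are (batch_size, m); A is (batch_size, m, m)."""
    return (
        W.shape == (batch_size, m)
        and U.shape == (batch_size, m)
        and A.shape == (batch_size, m, m)
    )


def is_antisymmetric(A, atol=1e-12):
    """Check A + A^T == 0 over the last two dimensions (true for Lévy area)."""
    return np.allclose(A, -np.swapaxes(A, -1, -2), atol=atol)


batch_size, m = 16, 2
rng = np.random.default_rng(0)
W = rng.standard_normal((batch_size, m))
U = rng.standard_normal((batch_size, m))

# Toy stand-in for A: the antisymmetric combination U (x) W - W (x) U,
# built by broadcasting (batch, m, 1) against (batch, 1, m).
A = U[:, :, None] * W[:, None, :] - W[:, :, None] * U[:, None, :]

print(levy_shapes_ok(W, U, A, batch_size, m))  # True
print(is_antisymmetric(A))                     # True
```

Against the real data structure, the same two checks could be applied to the tensors returned by bm_general(0.0, 0.1) after converting them with .cpu().numpy().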
patrick-kidger commented 3 years ago

Resolved in #28