google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.
Apache License 2.0

Scalar noise SDE shape check fails #27

Closed: lxuechen closed this issue 4 years ago

lxuechen commented 4 years ago

On the latest dev branch, the contract check fails for the following:

import torch

from torchsde import BrownianInterval, BaseSDE, sdeint

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

batch_size, d, m = 16, 3, 2
t0, t1 = 0.0, 0.3
dtype = torch.get_default_dtype()
y0 = torch.zeros(batch_size, d, device=device)

class ScalarSDE(BaseSDE):

    def f(self, t, y):
        return torch.sin(y)

    def g(self, t, y):
        # Output has shape (batch_size, d), the same as y; this is what the scalar-noise check rejects.
        return torch.cos(y).sigmoid()

sde = ScalarSDE(sde_type='ito', noise_type='scalar')
# Scalar noise: one Brownian channel per batch element, hence shape (batch_size, 1).
bm = BrownianInterval(t0=t0, t1=t1, shape=(batch_size, 1), dtype=dtype, device=device, levy_area_approximation='foster')
sdeint(sde=sde, y0=y0, ts=[t0, t1], bm=bm)
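To make the mismatch concrete, here is a minimal pure-Python sketch of the shape bookkeeping in the example above. It assumes that scalar noise requires g to return (batch_size, d, 1) and the Brownian motion to have shape (batch_size, 1); expected_shapes and noise_term_shape are hypothetical helpers for illustration, not torchsde functions, and they work on shape tuples rather than tensors.

```python
def expected_shapes(batch_size, state_dim):
    """Shapes a scalar-noise SDE is assumed to need (hypothetical helper)."""
    return {
        "y": (batch_size, state_dim),       # state
        "g": (batch_size, state_dim, 1),    # diffusion with a single noise channel
        "dW": (batch_size, 1),              # Brownian increment
    }


def noise_term_shape(g_shape, dw_shape):
    """Shape of the noise term g @ dW (batched matrix-vector product)."""
    *batch, state_dim, noise_dim = g_shape
    dw_batch, dw_noise = list(dw_shape[:-1]), dw_shape[-1]
    if dw_batch != batch or dw_noise != noise_dim:
        raise ValueError(f"g {tuple(g_shape)} and dW {tuple(dw_shape)} are incompatible")
    return (*batch, state_dim)


batch_size, state_dim = 16, 3
shapes = expected_shapes(batch_size, state_dim)
g_from_example = (batch_size, state_dim)  # torch.cos(y).sigmoid() keeps y's shape

print(g_from_example == shapes["g"])                  # False: the check rejects this
print(noise_term_shape(shapes["g"], shapes["dW"]))    # (16, 3), matching y
```

Under the same assumption, one way to satisfy the current check from the user side is to have g return torch.cos(y).sigmoid().unsqueeze(-1), which turns (batch_size, d) into (batch_size, d, 1).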

It might also be worth pinning down exactly which API we want. I think g should output either (batch_size, d, 1) or (batch_size, d), or preferably accept both if possible, for backward compatibility.
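A minimal sketch of the "accept both" option, assuming a single batch dimension as in the example: the check could treat a (batch_size, d) output as (batch_size, d, 1) and reject anything else. normalize_scalar_g_shape is a hypothetical name, not an existing torchsde function, and it operates on shape tuples so the logic stands on its own.

```python
def normalize_scalar_g_shape(y_shape, g_shape):
    """Return the (batch_size, d, 1) shape a scalar-noise g is treated as.

    Accepts g of shape (batch_size, d) or (batch_size, d, 1), assuming y has
    shape (batch_size, d) with exactly one batch dimension.
    """
    batch_size, d = y_shape
    g_shape = tuple(g_shape)
    if g_shape == (batch_size, d, 1):
        return g_shape
    if g_shape == (batch_size, d):
        # Backward-compatible form: append the trailing noise dimension.
        return (batch_size, d, 1)
    raise ValueError(
        f"scalar noise expects g of shape {(batch_size, d)} or "
        f"{(batch_size, d, 1)}, got {g_shape}"
    )


print(normalize_scalar_g_shape((16, 3), (16, 3)))     # (16, 3, 1)
print(normalize_scalar_g_shape((16, 3), (16, 3, 1)))  # (16, 3, 1)
```

In tensor code the backward-compatible branch would correspond to calling g.unsqueeze(-1). Because y has exactly one batch dimension here, the two accepted forms can never be confused with each other; that assumption is what stops holding once an arbitrary number of batch dimensions is allowed.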

patrick-kidger commented 4 years ago

Yes, this was a conscious choice. At the moment the check requires shape (..., d, 1).

This raises an interesting point. I think it's reasonable to support either multiple batch dimensions (and I realise much of the rest of the code doesn't do this yet), or scalar diffusions returning either of the shapes you mention, but not both. If we did both, the logic for deciding which dimensions count as batch dimensions and which don't would end up rather strange.
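One possible way to see the conflict, sketched under the assumption that batch dimensions are arbitrary and that g's shape is interpreted only from its trailing dimensions and d (the last dimension of y): when d = 1, a single g shape admits two readings. interpretations is a hypothetical helper for illustration only.

```python
def interpretations(g_shape, d):
    """List the ways g_shape could be read for scalar noise with arbitrary batch dims.

    Each reading is (batch_shape, form), where form is "(..., d)" or "(..., d, 1)".
    Only the trailing dimensions and d are consulted, not y's batch shape.
    """
    g_shape = tuple(g_shape)
    readings = []
    if len(g_shape) >= 1 and g_shape[-1] == d:
        readings.append((g_shape[:-1], "(..., d)"))
    if len(g_shape) >= 2 and g_shape[-2:] == (d, 1):
        readings.append((g_shape[:-2], "(..., d, 1)"))
    return readings


print(interpretations((16, 3, 1), d=3))  # one reading: batch (16,), form (..., d, 1)
print(interpretations((16, 1, 1), d=1))  # two readings: batch (16, 1) or batch (16,)
```

Comparing against y's own batch shape would disambiguate these cases, at the cost of making the interpretation of g depend on y's batch shape rather than on g alone.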

Personally, I lean quite strongly towards supporting multiple batch dimensions. I made the conscious decision not to in previous projects, and have since regretted it.