google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

Fix notebook. #74

Closed: lxuechen closed this 3 years ago

lxuechen commented 3 years ago

Updates on BI and added an example for Stratonovich SDEs.

patrick-kidger commented 3 years ago

I'm tempted to suggest that backprop should be included in this notebook. (Or at least a notebook.) It's pretty important for a lot of use cases, and it has a couple of caveats worth knowing about (i.e. prefer Stratonovich over Ito; adjoint vs non-adjoint). I think the latent SDE example is rather involved / tricky to follow, so it'd be nice to have a simple example of this behaviour.

lxuechen commented 3 years ago
  • Cell 1:

    • Can the SDE definition switch to dy(t) = f(t, y(t), theta) ...? In particular calling the state y and ordering the arguments t, y, for consistency with the code.

Done.

  • Cell 2:

    • For the simple SDE, can we avoid using torch.nn.Parameter. I've found that a surprising number of people have no idea what that is / that it exists, and are really only familiar with torch.nn.Linear etc. If you like (not a strong preference) the sample SDE could be a simple neural SDE with Linear or single-hidden-layer-MLP vector fields, which would help tie it into familiar territory for most people?
    • For a notebook giving a first introduction, I'd be inclined to use general noise for consistency with the mathematics just presented.

I don't see this as a big problem. The name Parameter itself should be quite sufficient in explaining what's happening there, IMHO. The reason for using diagonal is that we may then use more sophisticated solvers. A demonstration using only the Euler solver seems quite bland. Happy to discuss more if you still think otherwise.
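For concreteness, a minimal sketch of the kind of cell being discussed (the class name and parameter values are illustrative, not necessarily the notebook's exact contents), keeping nn.Parameter and diagonal noise, with the state called y and arguments ordered (t, y):

import torch
import torchsde

class SDE(torch.nn.Module):
    noise_type = 'diagonal'
    sde_type = 'ito'

    def __init__(self):
        super().__init__()
        # nn.Parameter just registers theta as a trainable scalar.
        self.theta = torch.nn.Parameter(torch.tensor(0.1))

    def f(self, t, y):
        # Drift, with the state called y and arguments ordered (t, y).
        return torch.sin(t) + self.theta * y

    def g(self, t, y):
        # Diagonal diffusion: output has the same shape as y.
        return 0.3 * torch.sigmoid(torch.cos(t) * torch.exp(-y))

sde = SDE()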

  • Cell 3:

    • d isn't used. I assume it's meant to be used in y0. Additionally, could it be renamed to state_size? This is consistent with the term batch_size, and with DOCUMENTATION.md.
    • Can T be specified on a separate line from batch_size and d, as it's not a dimension size?
    • Can y0 = torch.zeros(batch_size, 1).fill_(0.1) be replaced with torch.full?
    • For the discussion on the different solvers - can there be a comment saying that setting the method to None, the default, will mean it uses an appropriate default method? (Once my method fixes go in, anyway.)

Good points! Done for d. T is actually a dimension: it's the output dimension of the resulting tensor. Done for torch.full. Done for method=None.
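Continuing the sketch above, the resolved points might look roughly like this (variable names per the discussion; `sde` is assumed to be the module defined earlier):

batch_size, state_size = 32, 1
t_size = 100  # number of time points; kept on its own line, separate from the dimension sizes

y0 = torch.full(size=(batch_size, state_size), fill_value=0.1)
ts = torch.linspace(0, 1, t_size)

# method=None (the default) lets torchsde choose an appropriate solver
# for the SDE's sde_type/noise_type.
ys = torchsde.sdeint(sde, y0, ts, method=None)
# ys has shape (t_size, batch_size, state_size)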

  • Cell 5:

    • I think names is reasonably niche / not-that-useful functionality, and I'd probably remove it from a first introduction. If you do want to keep it, though, then I wouldn't use h, as that conflicts with the prior drift.
    • Can the repeated plotting code be factored out to a function?

This actually isn't that niche IMHO. In latent SDEs it's used for plotting the prior samples, since there are actually two SDEs there that share part of their structure. I was also imagining people training only the drift with the most standard forward method of modules.

Factoring out the plotting might make sense, though I think it may complicate the demonstration, as then we would need to make the plotting function take in xlabel, ylabel, title, and maybe other parameters in the future. Not that we don't have one in the codebase (swiss_knife_plotter in one of the utility files); it might just complicate the demo.
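To illustrate the names functionality being discussed (the class and values below are purely illustrative; names is assumed to map roles like 'drift' to method names):

import torch
import torchsde

class LatentLikeSDE(torch.nn.Module):
    # Illustrative only: posterior drift f, prior drift h, shared diffusion g,
    # roughly the structure used in the latent SDE example.
    noise_type = 'diagonal'
    sde_type = 'ito'

    def __init__(self):
        super().__init__()
        self.theta = torch.nn.Parameter(torch.tensor(0.1))

    def f(self, t, y):  # posterior drift
        return self.theta * y

    def h(self, t, y):  # prior drift
        return -y

    def g(self, t, y):  # shared diffusion
        return 0.3 * torch.ones_like(y)

sde = LatentLikeSDE()
y0 = torch.full((32, 1), 0.1)
ts = torch.linspace(0, 1, 100)

ys_posterior = torchsde.sdeint(sde, y0, ts)                    # uses f and g by default
ys_prior = torchsde.sdeint(sde, y0, ts, names={'drift': 'h'})  # prior samples via h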

  • Cell 7:

    • BPath and BTree are used in the text. (BInterval is used in the code.)
    • BInterval is raising warnings. Explicitly passing the start point as the first query point will disable those. (Generally I regard Brownian motion as defined on intervals, not points, anyway, so I feel like this is also encouraging good practice thinking.)
    • Can the repeated plotting code be factored out to a function?

Sorry I overlooked this. Fixed.
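A small sketch of the querying pattern suggested above (assuming the interval-query form of BrownianInterval; the sizes are illustrative):

import torchsde

t0, t1 = 0.0, 1.0
bm = torchsde.BrownianInterval(t0=t0, t1=t1, size=(32, 1))

# Query increments over intervals, anchoring the first query at t0 as
# suggested above, which should keep the warnings quiet.
w_first = bm(t0, 0.5)   # W(0.5) - W(t0)
w_second = bm(0.5, t1)  # W(t1) - W(0.5)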

  • Cell 8:

    • I wouldn't specify the levy area approximation here, it's an additional complexity.
    • It's not quite accurate to say that this gives deterministic behaviour: it's only doing that because it's using a fixed solver.
    • These kinds of concerns (levy area; speed/memory trade-offs; determinism) perhaps warrant their own notebook, instead of being part of this one? They're a bit more advanced, and the defaults generally do the right thing for most use cases.

I'm tempted to suggest that backprop should be included in this notebook. (Or at least a notebook.) It's pretty important for a lot of use cases, and it has a couple of caveats worth knowing about (i.e. prefer Stratonovich over Ito; adjoint vs non-adjoint). I think the latent SDE example is rather involved / tricky to follow, so it'd be nice to have a simple example of this behaviour.

So I thought deeply about including grad computation initially. The reason for not including it is that just doing squared loss minimization on some observations would basically end up learning an ODE -- it wouldn't really motivate SDEs that well. I obviously can't complain if you have better ideas on creating a simple example w/o getting into complicated models, though.

patrick-kidger commented 3 years ago

I don't see this as a big problem. The name Parameter itself should be quite sufficient in explaining what's happening there, IMHO. The reason for using diagonal is that we may then use more sophisticated solvers. A demonstration using only the Euler solver seems quite bland. Happy to discuss more if you still think otherwise.

The use of Parameter is just a small extra hurdle for many, is all. I think a reasonable fraction of our users may not be PyTorch experts -- they're here for the SDEs. I think using general is more important than using SRK; additionally Euler is a more familiar point of reference than SRK.

That said, I think you're right that diagonal/SRK are worth mentioning. I think I'd remove the BInterval stuff in favour of expanding the discussion on noise types. (Including a discussion on available solvers.)
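To illustrate the shape convention that a noise-type discussion would cover (shapes per DOCUMENTATION.md; the class below is only an illustration built from nn.Linear layers):

import torch

class GeneralNoiseSDE(torch.nn.Module):
    # General noise: the diffusion returns a (batch_size, state_size, brownian_size)
    # matrix, whereas diagonal noise returns a (batch_size, state_size) tensor
    # with the same shape as y.
    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, state_size=2, brownian_size=3):
        super().__init__()
        self.state_size = state_size
        self.brownian_size = brownian_size
        self.drift = torch.nn.Linear(state_size, state_size)
        self.diffusion = torch.nn.Linear(state_size, state_size * brownian_size)

    def f(self, t, y):
        return self.drift(y)

    def g(self, t, y):
        return self.diffusion(y).view(y.size(0), self.state_size, self.brownian_size)

As far as I recall, the higher-order solvers such as SRK need diagonal or additive noise, so general noise ties you to the simpler solvers like Euler, which is the trade-off in question.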

Good points! Done for d. T is actually a dimension: it's the output dimension of the resulting tensor.

Ah, right - I missed that. I've more usually seen T used as the terminal time rather than the number of points (which I think is more often written as N in mathematics). I'd change that to t_size. (At least that's what I call it in the documentation. Generally some *_size makes sense as a name.)

This actually isn't that niche IMHO. In latent SDEs it's used for plotting the prior samples, since there are actually two SDEs there that share part of their structure. I was also imagining people training only the drift with the most standard forward method of modules.

Fair enough. Do you think it's worth having some default zero-diffusion that makes only training the drift easy? (Not saying I do, just thinking out loud.)

Factoring out the plotting might make sense, though I think it may complicate the demonstration, as then we would need to make the plotting function take in xlabel, ylabel, title, and maybe other parameters in the future. Not that we don't have one in the codebase (swiss_knife_plotter in one of the utility files); it might just complicate the demo.

I think it simplifies the demonstration. Its details don't matter from a pedagogical standpoint, and as long as it's well named (plot) then it's clear that it's not relevant. (I wouldn't put it in one of the utility files, so as to keep things self-contained.)
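A sketch of the kind of helper meant here (purely illustrative; it assumes ys has shape (t_size, batch_size, 1) as returned by sdeint):

import matplotlib.pyplot as plt

def plot(ts, ys, xlabel='$t$', ylabel='$Y_t$', title=''):
    # Keep the notebook cells short; the plotting details aren't the point.
    plt.figure()
    plt.plot(ts.cpu().numpy(), ys.squeeze(-1).detach().cpu().numpy(), alpha=0.6)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.show()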

So I thought deeply about including grad computation initially. The reason for not including it is that just doing squared loss minimization on some observations would basically end up learning an ODE -- it wouldn't really motivate SDEs that well. I obviously can't complain if you have better ideas on creating a simple example w/o getting into complicated models, though.

I'm just suggesting some very straightforward:

ys = sdeint(....)
y = ys[-1]
y.backward(torch.randn_like(y))

Just to make the point that it's possible.

If you do want a reasonable motivation for doing it, then it can be something like:

ys = sdeint(...)
y = ys[-1]
value = f(y)
loss = (target - value.mean(dim=0)).pow(2).mean().sqrt()
loss.backward()

which is one of the standard uses of SDE models in the mathematical literature, as a way to model terminal densities satisfying certain (f) statistics. In modern ML terms it looks very MMD-like.
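Filled out end to end, a hedged sketch of that idea might look something like the following (the TerminalSDE class, the target, and the hyperparameters are all illustrative, not part of the notebook; it uses sdeint_adjoint and a Stratonovich SDE, in line with the backprop caveats mentioned earlier):

import torch
import torchsde

class TerminalSDE(torch.nn.Module):
    noise_type = 'diagonal'
    sde_type = 'stratonovich'  # per the "prefer Stratonovich for backprop" caveat

    def __init__(self, state_size=1):
        super().__init__()
        self.drift_net = torch.nn.Linear(state_size + 1, state_size)
        self.register_buffer('sigma', torch.tensor(0.5))  # fixed diffusion scale

    def f(self, t, y):
        tt = torch.full_like(y[:, :1], float(t))  # time as an extra input channel
        return self.drift_net(torch.cat([tt, y], dim=1))

    def g(self, t, y):
        return self.sigma.expand_as(y)

sde = TerminalSDE()
y0 = torch.full((128, 1), 0.1)
ts = torch.tensor([0.0, 1.0])  # only the terminal value is needed
target = torch.tensor(1.0)     # desired terminal statistic
optimizer = torch.optim.Adam(sde.parameters(), lr=1e-2)

for _ in range(100):
    optimizer.zero_grad()
    ys = torchsde.sdeint_adjoint(sde, y0, ts, dt=1e-2)  # adjoint: constant memory in time
    y = ys[-1]
    value = y  # here f is the identity; any differentiable statistic works
    loss = (target - value.mean(dim=0)).pow(2).mean().sqrt()
    loss.backward()
    optimizer.step()

Since the diffusion here has no trainable parameters, only the drift is learned, pushing the terminal mean toward the target statistic.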

lxuechen commented 3 years ago

All done except:

  1. I don't see using nn.Parameter as a hurdle. Rather, I see putting a big neural net there as complicating the demo.
  2. I think diagonal is more intuitive than general, since the function output looks more like the ODE case.

Fair enough. Do you think it's worth having some default zero-diffusion that makes only training the drift easy? (Not saying I do, just thinking out loud.)

I don't fully understand what you mean by "zero-diffusion". If this just means a diffusion function that always outputs zeros, then the SDE is essentially an ODE. If this means a constant diffusion function, i.e. one without any trainable parameters, then I agree there's possibly something interesting here we could try.

I agree with having backward, factoring out plotting, and using t_size.

patrick-kidger commented 3 years ago

Haha alright, let's agree to disagree on the parameter/noise.

Zero diffusion - you described people using the forward method for the drift; this only seems natural to me if the diffusion isn't present, with the only natural default being zero. (Which, as you say, reduces to the ODE case.) I'm not sure I really see a need for this; I was just positing what I thought was the natural scenario for what you were describing.

patrick-kidger commented 3 years ago

Cell 1:

Cell 2:

Cell 3:

Cell 4:

Cell 9:

patrick-kidger commented 3 years ago

(Overall despite my criticism I like it btw.)

lxuechen commented 3 years ago

All done! Genuinely thankful for the thorough comments!

patrick-kidger commented 3 years ago

LGTM!