DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

More clarification regarding `s_span` in CNF model tutorials #69

Closed: lucaslie closed this issue 2 years ago

lucaslie commented 3 years ago

Describe the bug

It's not really a bug, but it can easily cause one down the line. ;)

In cells 9 and 10 of your CNF tutorials (07a and 07b), you reverse the flow of the CNF by changing the model's s_span, so that you can sample from it and plot the result:

model[1].s_span = torch.linspace(1, 0, 2)
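Why reversing the span turns the trained density model into a sampler can be seen on a scalar toy problem: integrating the same vector field from s = 1 back to s = 0 approximately undoes the forward integration. This is only a sketch. The hand-written RK4 loop and the vector field `f` are hypothetical stand-ins for the tutorial's learned dynamics and solver, not torchdyn code.

```python
import math


def f(s, x):
    # Hypothetical toy vector field standing in for the CNF's learned dynamics.
    return math.sin(x) + 0.5 * s


def rk4_integrate(x, s_span, steps=100):
    # Fixed-step RK4 from s_span[0] to s_span[1].
    s0, s1 = s_span
    h = (s1 - s0) / steps  # negative when the span is reversed, e.g. (1, 0)
    s = s0
    for _ in range(steps):
        k1 = f(s, x)
        k2 = f(s + h / 2, x + h * k1 / 2)
        k3 = f(s + h / 2, x + h * k2 / 2)
        k4 = f(s + h, x + h * k3)
        x = x + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        s = s + h
    return x


x0 = 0.3
z = rk4_integrate(x0, (0.0, 1.0))      # forward span: data -> latent
x_back = rk4_integrate(z, (1.0, 0.0))  # reversed span: latent -> data
print(f"x0={x0}, z={z:.4f}, recovered={x_back:.6f}")
```

Because the direction is just the order of the span, whichever span was set last is the one the next integration (including a training step) will use.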

After that, you never revert the flow to its original direction, which could be done by running:

model[1].s_span = torch.linspace(0, 1, 2)
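One way to catch this mistake before it costs a training run is to check the direction of `s_span` right before training resumes. A minimal sketch, assuming only that `s_span` is indexable and that training integrates from 0 to 1 as in the tutorials: `FlowLayer` is a hypothetical stand-in for `model[1]`, `check_forward_span` is an illustrative helper (not part of torchdyn), and plain lists stand in for the `torch.linspace` tensors.

```python
class FlowLayer:
    """Hypothetical stand-in for model[1]; only its s_span attribute matters here."""

    def __init__(self, s_span):
        self.s_span = s_span


def check_forward_span(layer):
    # Training integrates data -> latent over an increasing span such as
    # torch.linspace(0, 1, 2). A decreasing span means a sampling step left the
    # flow reversed, so fail loudly instead of silently training the inverse map.
    if layer.s_span[0] > layer.s_span[-1]:
        raise RuntimeError(
            "s_span is reversed; reset it to an increasing span before training"
        )
    return layer


layer = FlowLayer([1.0, 0.0])  # state left behind after sampling/plotting
try:
    check_forward_span(layer)
except RuntimeError as err:
    print(err)
```

Raising rather than silently flipping the span keeps the check independent of the span's type and makes the forgotten reset visible.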

I think it would be great to have some additional clarification about that step.

I based some of my own code on these tutorials and kept training after plotting the network. Obviously, the resulting network was garbage. :(

Steps to Reproduce

Steps to reproduce the behavior:

  1. Run either of the CNF tutorials.
  2. After visualizing the result, repeat the training procedure.
  3. Visualize it again and you will see odd results.

Expected behavior

Some clarification on making sure the flow points in the right direction. Maybe there is even a way to abstract some of these specifics away from the user by providing an API to switch between the two "modes".
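A context manager is one possible shape for such a mode switch: it swaps in the reversed span only for the duration of a `with` block and restores the original in a `finally` clause, so even a plot that raises cannot leave the model reversed. This is a sketch under assumptions, not torchdyn's API: `FlowLayer` and `sampling_mode` are hypothetical names, and plain lists stand in for the `torch.linspace` tensors the tutorials use.

```python
from contextlib import contextmanager


class FlowLayer:
    """Hypothetical stand-in for model[1]; it only carries s_span."""

    def __init__(self, s_span):
        self.s_span = s_span


@contextmanager
def sampling_mode(layer, reversed_span):
    # Temporarily swap in the reversed integration interval for sampling/plotting.
    original_span = layer.s_span
    layer.s_span = reversed_span
    try:
        yield layer
    finally:
        # Restore the training direction even if the body of the with-block raises.
        layer.s_span = original_span


layer = FlowLayer([0.0, 1.0])
with sampling_mode(layer, [1.0, 0.0]):
    print(layer.s_span)  # reversed while sampling
print(layer.s_span)      # original span restored for training
```

In the tutorial setting this would read roughly `with sampling_mode(model[1], torch.linspace(1, 0, 2)):` with the sampling and plotting code inside the block, so any training code after the block sees the forward span again.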

Zymrael commented 3 years ago

Hi @lucaslie, thanks for writing this up so clearly. This was a deliberate choice: you might have noticed that other normalizing flow implementations hide some of these details behind sample methods, whereas we decided to show exactly which steps are needed to sample from a CNF. Still, it's probably a good idea to add an extra comment, or to revert s_span by default after plotting, so that this problem is avoided in the future. We'd welcome your PR for that!