google-research / torchsde

Differentiable SDE solvers with GPU support and efficient sensitivity analysis.

Memory usage #89

Closed · rubick1896 closed this issue 3 years ago

rubick1896 commented 3 years ago

Hi, I've observed that memory usage increases as training goes on. This is unusual for a standard NN, since the number of parameters is fixed. Is this normal for this model, or did I get something wrong? I also noticed there is a cache_size parameter. What is the unit of this integer? Is it in MB?

thanks

patrick-kidger commented 3 years ago

That's not normal behaviour. Can you provide an MWE demonstrating this? [1] This usually occurs when there's a mistake in the training loop itself, rather than anything to do with torchsde.
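
For example, a common training-loop mistake that produces exactly this symptom is accumulating the loss tensor itself for logging, which keeps autograd history from every iteration alive. A minimal sketch (nothing to do with torchsde specifically):

```python
import torch

model = torch.nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

running_loss = 0.0
for step in range(1000):
    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = ((model(x) - y) ** 2).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()

    # Bug: `loss` still carries autograd history, so `running_loss` accumulates
    # an ever-growing graph and memory increases over training.
    running_loss += loss
    # Fix: running_loss += loss.item()
```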

The cache_size parameter doesn't have units like that -- it specifies the size of the LRU cache used in the Brownian Interval, which is measured in number of samples. That's pretty technical; you should almost certainly leave it alone. (If you're trying to reduce memory consumption then I'd look elsewhere -- the memory consumed by the Brownian Interval is generally negligible.)
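
For reference, a minimal sketch of constructing the Brownian Interval explicitly and setting cache_size (the toy SDE here is made up purely for illustration):

```python
import torch
import torchsde

class ToySDE(torch.nn.Module):
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self):
        super().__init__()
        self.theta = torch.nn.Parameter(torch.tensor(0.1))

    def f(self, t, y):                     # drift
        return -self.theta * y

    def g(self, t, y):                     # diffusion (diagonal noise)
        return 0.2 * torch.ones_like(y)

batch_size, state_size = 16, 3
y0 = torch.full((batch_size, state_size), 0.5)
ts = torch.linspace(0.0, 1.0, 20)

# cache_size is the number of entries in the Brownian Interval's LRU cache of
# recently computed increments -- not a size in bytes or MB.  45 is the default.
bm = torchsde.BrownianInterval(t0=0.0, t1=1.0, size=(batch_size, state_size),
                               cache_size=45)
ys = torchsde.sdeint(ToySDE(), y0, ts, bm=bm, method="euler", dt=0.05)
```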

[1] I'd note that if you're using adaptive step sizes without the adjoint method, then this can plausibly occur because the cost of solving changes during training -- typically NFEs go up over training -- but that's usually not substantial.

lxuechen commented 3 years ago

Hi,

I think the chances of there being a subtle memory-related bug in torchsde are quite low, though I don't want to totally rule out that possibility. If you have a minimal reproducible example, I'd also be happy to take a look and see if I can help solve the problem. It's very hard to reason about this without context.

rubick1896 commented 3 years ago

Thanks for the reply. It will take some time to put together a minimal reproducible example, since the model relies on a framework in the middle and I don't control the training loop directly. In the meantime, I want to provide more context in words.

I am trying to model a rolling forecasting problem: given x1, x2, ..., xt, predict xt+1; at the next time step you have x1, x2, ..., xt+1 and predict xt+2. The training strategy is to pick a random split point t, look back t time steps to get a single training case, and use xt+1 as the label. So unlike standard NN training, the training cases are not the same in each epoch; they are random samples from one long time series.

If my description is not clear, maybe take a look at the code here: https://github.com/zalandoresearch/pytorch-ts/blob/master/pts/transform/split.py
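
Roughly, the idea is something like this hypothetical sketch (not the actual framework code; names are made up):

```python
import torch

def sample_training_case(series):
    # series: (T, features).  Pick a random split point t, use x_1 ... x_t as
    # the training case and x_{t+1} as the label.
    t = torch.randint(1, series.size(0) - 1, (1,)).item()
    past = series[:t]          # x_1 ... x_t
    target = series[t]         # x_{t+1}
    return past, target

series = torch.randn(1000, 4)  # one long time series
past, target = sample_training_case(series)
```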

I am using an RNN to encode a training case and feed that to torchsde.
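
Concretely, the setup is shaped roughly like this hypothetical sketch (not my actual code; sizes and names are made up): the RNN encodes the window into an initial latent state, which torchsde then evolves.

```python
import torch
import torchsde

class Encoder(torch.nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.rnn = torch.nn.GRU(input_size, hidden_size, batch_first=True)
        self.proj = torch.nn.Linear(hidden_size, latent_size)

    def forward(self, past):               # past: (batch, time, features)
        _, h = self.rnn(past)
        return self.proj(h[-1])            # initial latent state y0

class LatentSDE(torch.nn.Module):
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, latent_size):
        super().__init__()
        self.drift = torch.nn.Linear(latent_size, latent_size)

    def f(self, t, y):
        return self.drift(y)

    def g(self, t, y):
        return 0.1 * torch.ones_like(y)

encoder, sde = Encoder(4, 32, 8), LatentSDE(8)
past = torch.randn(16, 24, 4)              # a batch of context windows
y0 = encoder(past)
ts = torch.linspace(0.0, 1.0, 2)
ys = torchsde.sdeint(sde, y0, ts, method="euler", dt=0.05)
```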

I just want to make sure that this setup isn't a game-changer, and that I should still expect memory usage to be roughly constant across epochs. If so, I will work on a minimal reproducible example.

lxuechen commented 3 years ago

The expectation is that VRAM usage should stay roughly the same across gradient updates. If you're using adaptive solvers, then yes, it's quite possible that the learned dynamics become harder to solve and therefore require many more function evaluations.

Somewhat contrary to what Patrick has suggested, I've seen quite a few cases where the NFE at the start of training is very different from the NFE at the end of training.
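
If it helps, one quick way to check is to count drift evaluations by hand over training; a hypothetical sketch (the nfe counter is not part of torchsde):

```python
import torch
import torchsde

class CountingSDE(torch.nn.Module):
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, size):
        super().__init__()
        self.net = torch.nn.Linear(size, size)
        self.nfe = 0                        # number of drift evaluations

    def f(self, t, y):
        self.nfe += 1
        return self.net(y)

    def g(self, t, y):
        return 0.1 * torch.ones_like(y)

sde = CountingSDE(3)
y0 = torch.zeros(8, 3)
ts = torch.linspace(0.0, 1.0, 2)
# With adaptive stepping, the number of steps (and hence NFE, and memory when
# not using the adjoint) can change as the learned dynamics change.
ys = torchsde.sdeint(sde, y0, ts, adaptive=True, dt=0.05, rtol=1e-3, atol=1e-3)
print(sde.nfe)
```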

Your description doesn't give the setting at a level of granularity where we can really help. You mention data x1, x2, ..., xt+1 and target xt+2. Does t change across gradient updates?

I'm closing this issue for now since not enough background has been given. Feel free to reopen if you can provide more context, or at the very least a small example that points towards a problem.