jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Possible leak in random number generation #25069

Open RadostW opened 3 days ago

RadostW commented 3 days ago

Description

When using the JAX-based package pychastic (an SDE solver), the JAX backend's memory usage keeps growing indefinitely.

import jax
import jax.numpy as np
import numpy.random 
import psutil
import pychastic

def main():

    for _ in range(100):

        initial_samples = numpy.random.random((10000, 3))
        def drift_fn(X):
            return -X

        def noise_fn(_):
            return np.eye(3)

        solver = pychastic.sde_solver.SDESolver(dt=0.01)
        problem = pychastic.sde_problem.SDEProblem(
            a=drift_fn,
            b=noise_fn,
            x0=initial_samples,
            tmax=10,
        )

        solver.solve_many(problem, n_trajectories=None, progress_bar=True)

        mem_usage_now = psutil.Process().memory_info().rss / 1024 ** 2

        # jax.clear_backends()

        del solver
        del problem
        del initial_samples
        del noise_fn
        del drift_fn

        print(f'Mem usage: {mem_usage_now} MB')

if __name__ == "__main__":
    main()

Output (abbreviated)

100%|████████████████████████████████| 1000/1000 [00:01<00:00, 716.20it/s]
Mem usage: 161.609375 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 869.53it/s]
Mem usage: 165.42578125 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 893.27it/s]
Mem usage: 168.7578125 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 885.05it/s]
Mem usage: 171.8359375 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 898.58it/s]
Mem usage: 175.03125 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 864.00it/s]
Mem usage: 178.3359375 MB
100%|████████████████████████████████| 1000/1000 [00:01<00:00, 863.09it/s]
Mem usage: 181.51171875 MB

I apologize for the contrived code to reproduce the issue. I'd be happy to chase the leak further, but I'm unfamiliar with any tools that could help diagnose the issue. Is there some way to see what's taking up all this space?

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.3.25
jaxlib: 0.3.25
numpy:  1.23.5
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]

Also tested, with the same result, on:

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.1.3
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]

justinjfu commented 9 hours ago

A couple of potentially helpful tips: (1) You can use jax.live_arrays() (https://jax.readthedocs.io/en/latest/_autosummary/jax.live_arrays.html) to get all of the live arrays and check for potential leaks. (2) Does manually running gc.collect() help with the problem?
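For reference, a minimal sketch of how both tips could be combined into a per-iteration check. It assumes a JAX version recent enough to provide jax.live_arrays(); the helper names live_array_report and run_iteration are hypothetical, and run_iteration is only a stand-in for the pychastic solver call, not the actual reproducer.

```python
import gc

import jax
import jax.numpy as jnp


def live_array_report():
    """Return (count, total_bytes) for the arrays JAX currently keeps alive."""
    arrays = jax.live_arrays()
    total_bytes = sum(a.nbytes for a in arrays)
    return len(arrays), total_bytes


def run_iteration(key):
    # Hypothetical stand-in for one solver call: draw random numbers
    # and reduce them to a Python float so no array outlives the call.
    noise = jax.random.normal(key, (10000, 3))
    return float(jnp.sum(noise))


key = jax.random.PRNGKey(0)
for step in range(5):
    key, subkey = jax.random.split(key)
    run_iteration(subkey)
    del subkey

    # Tip (2): force a garbage-collection pass before measuring.
    gc.collect()

    # Tip (1): count the arrays that are still alive and their total size.
    count, total = live_array_report()
    print(f"step {step}: {count} live arrays, {total / 1024 ** 2:.3f} MB")
```

If the count or total keeps rising from one step to the next, something is still holding references to arrays. If both stay flat while the process RSS grows, the memory is probably held somewhere other than live JAX arrays, for example in compilation caches.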