odlgroup / odl

Operator Discretization Library https://odlgroup.github.io/odl/
Mozilla Public License 2.0

Fix a memory leak in the Operator base class #1632

Closed peterisr closed 9 months ago

peterisr commented 9 months ago

References to operator class definitions were cached in the argument cache of the _dispatch_call_args function, so the garbage collector could never collect them. Normally that wouldn't be an issue, as class definitions don't consume much memory. However, a typical pattern for defining adjoint operators in ODL is to use nested classes that hold references back to their enclosing objects. Depending on the Operator, this can cause considerable memory leaks.

For example, the adjoint method of RayTransform objects defines a nested RayBackProjection class, which holds a reference back to the RayTransform, which in turn references geometry objects. None of these can be reclaimed by the GC while the nested class stays in the cache.
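To illustrate the mechanism outside of ODL, here is a minimal sketch using only the standard library. The names dispatch_info, Outer and Inner are hypothetical stand-ins (not ODL identifiers), and a plain functools.lru_cache is assumed in place of cache_arguments().

```python
import functools
import gc
import weakref


@functools.lru_cache()  # default maxsize is 128
def dispatch_info(cls):
    """Hypothetical per-class analysis, cached by its argument."""
    return cls.__name__


class Outer:
    """Stand-in for an operator that builds its adjoint as a nested class."""

    @property
    def adjoint(self):
        outer = self

        class Inner:
            # The closure keeps a reference back to the enclosing object.
            def parent(self):
                return outer

        dispatch_info(Inner)  # the cache key now strongly references Inner
        return Inner()


def make_and_drop():
    obj = Outer()
    obj.adjoint
    return weakref.ref(obj)  # only a weak reference leaves this function


ref = make_and_drop()
gc.collect()
leaked = ref() is not None  # the Outer instance is kept alive via the cache

dispatch_info.cache_clear()
gc.collect()
freed = ref() is None  # once the cache lets go of Inner, Outer is collectable
```

The chain is cache key -> nested class -> closure -> enclosing object, which is why dropping every local reference is not enough to free the object.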

Here's a reproducer of the leak:

import gc
import odl

def run_iteration():
    # Build a ray transform, touch its adjoint, then drop every local reference.
    reco_space = odl.uniform_discr(min_pt=[-20, -20], max_pt=[20, 20],
                                   shape=[128, 128], dtype='float32')
    geometry = odl.tomo.parallel_beam_geometry(reco_space, num_angles=100)
    ray_trafo = odl.tomo.RayTransform(reco_space, geometry)
    ray_trafo.adjoint  # This line causes a memory leak

def main():
    n = 150
    for i in range(1, 1 + n):
        run_iteration()
        gc.collect()
        print("After iter", i, "there are", len(gc.get_objects()), "python objects")

if __name__ == '__main__':
    main()

Running the code without this fix leaks objects on each iteration (57 per iteration in the run shown here).

After iter 1 there are 49945 python objects
After iter 2 there are 50002 python objects
After iter 3 there are 50059 python objects
[..]
After iter 126 there are 57070 python objects
After iter 127 there are 57127 python objects
After iter 128 there are 57127 python objects # further leaking stops here because the lru_cache() [1] used to implement cache_arguments() is limited to 128 objects by default
[..]
After iter 150 there are 57127 python objects
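The plateau can be reproduced in isolation. The following is a rough sketch, assuming a plain functools.lru_cache with default settings in place of cache_arguments(); describe and make_class are hypothetical names used only for illustration.

```python
import functools
import gc
import weakref


@functools.lru_cache()  # no maxsize given, so the default of 128 applies
def describe(cls):
    return cls.__name__


def make_class():
    # A fresh class object per call, analogous to a nested adjoint class.
    class Local:
        pass

    describe(Local)
    return weakref.ref(Local)


refs = [make_class() for _ in range(150)]
gc.collect()

# Only the classes still held as cache keys survive garbage collection.
alive = sum(1 for r in refs if r() is not None)
maxsize = describe.cache_info().maxsize
```

Once the cache is full, each new entry evicts the oldest one, so the number of classes kept alive stops growing at the cache size.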

With the fix, the object count after each iteration remains constant.

After iter 1 there are 49872 python objects
After iter 2 there are 49872 python objects
After iter 3 there are 49872 python objects
[..]
After iter 150 there are 49872 python objects

The fix I've implemented moves away from an argument cache and instead caches the returned values on the class itself. This ensures that the cache_arguments() decorator no longer holds permanent references to class objects.
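As a rough sketch of the idea (not the actual diff), the result can be stored as an attribute of the class, so the cached value lives and dies with the class. The function dispatch_info and the attribute name _dispatch_info_cache are hypothetical and chosen only for this example.

```python
import gc
import weakref


def dispatch_info(cls):
    """Hypothetical per-class analysis, cached as an attribute of cls."""
    # Look only in cls.__dict__ so a subclass does not reuse its parent's value.
    try:
        return cls.__dict__['_dispatch_info_cache']
    except KeyError:
        info = cls.__name__  # placeholder for the real, more expensive work
        cls._dispatch_info_cache = info
        return info


def make_class():
    class Local:
        pass

    first = dispatch_info(Local)   # computed and stored on the class
    second = dispatch_info(Local)  # served from the class attribute
    return weakref.ref(Local), first, second


ref, first, second = make_class()
gc.collect()
collected = ref() is None  # no external cache keeps the class alive
```

Because nothing outside the class refers to it, the class and everything it references become collectable as soon as the last user reference is dropped.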


[1] https://docs.python.org/3/library/functools.html#functools.lru_cache