keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Enormous memory usage after batched forward passes with TensorFlow 2.16.1 (CPU) #19500

Open Dobiasd opened 6 months ago

Dobiasd commented 6 months ago
import numpy as np
import psutil
import tensorflow as tf

model = tf.keras.applications.ResNet152V2()

images = np.zeros([20, 224, 224, 3], dtype=np.uint8)

for run in range(10):
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f"Memory usage after {run} run(s) (in MiB): {memory_usage_in_MiB:.3f}", flush=True)
    model(images)
Memory usage after 0 run(s) (in MiB): 790.715
Memory usage after 1 run(s) (in MiB): 5969.941
Memory usage after 2 run(s) (in MiB): 7112.559
Memory usage after 3 run(s) (in MiB): 7787.625
Memory usage after 4 run(s) (in MiB): 7827.727
Memory usage after 5 run(s) (in MiB): 6785.316
Memory usage after 6 run(s) (in MiB): 5971.680
Memory usage after 7 run(s) (in MiB): 6615.105
Memory usage after 8 run(s) (in MiB): 7151.496
Memory usage after 9 run(s) (in MiB): 6679.250

(Dockerfile to reproduce)

And it's not just tf.keras.applications.ResNet152V2(). It also happens (for example) with tf.keras.applications.inception_resnet_v2.InceptionResNetV2 and tf.keras.applications.inception_v3.InceptionV3. And it also happens when using Python 3.12.2 instead of 3.11.8.

With TensorFlow 2.15.1 (instead of 2.16.1), however, the memory usage does not explode:

Memory usage after 0 run(s) (in MiB): 1038.891
Memory usage after 1 run(s) (in MiB): 1154.062
Memory usage after 2 run(s) (in MiB): 1154.508
Memory usage after 3 run(s) (in MiB): 1154.684
Memory usage after 4 run(s) (in MiB): 1215.840
Memory usage after 5 run(s) (in MiB): 1154.664
Memory usage after 6 run(s) (in MiB): 1154.816
Memory usage after 7 run(s) (in MiB): 1155.059
Memory usage after 8 run(s) (in MiB): 1154.715
Memory usage after 9 run(s) (in MiB): 1155.367

(Dockerfile to reproduce)

Workaround: Replacing model(images).numpy() with model.predict(images) improves the situation, i.e., it only leaks a little bit.

sirfz commented 6 months ago

Try wrapping your model with tf.function. If I recall correctly, we recently observed the same issue and this fixed it.

fchollet commented 6 months ago

I tried running your snippet 20x and added a call to gc.collect() inside the loop. Here's what I get:

Memory usage after 0 run(s) (in MiB): 950.141
Memory usage after 1 run(s) (in MiB): 1399.016
Memory usage after 2 run(s) (in MiB): 1231.875
Memory usage after 3 run(s) (in MiB): 1204.109
Memory usage after 4 run(s) (in MiB): 1353.109
Memory usage after 5 run(s) (in MiB): 1525.312
Memory usage after 6 run(s) (in MiB): 1807.875
Memory usage after 7 run(s) (in MiB): 1594.766
Memory usage after 8 run(s) (in MiB): 1609.703
Memory usage after 9 run(s) (in MiB): 1556.141
Memory usage after 10 run(s) (in MiB): 1720.438
Memory usage after 11 run(s) (in MiB): 1606.094
Memory usage after 12 run(s) (in MiB): 1803.406
Memory usage after 13 run(s) (in MiB): 1593.234
Memory usage after 14 run(s) (in MiB): 1628.969
Memory usage after 15 run(s) (in MiB): 1665.312
Memory usage after 16 run(s) (in MiB): 1428.234
Memory usage after 17 run(s) (in MiB): 1406.406
Memory usage after 18 run(s) (in MiB): 1117.484
Memory usage after 19 run(s) (in MiB): 1380.219

Memory usage has higher variance than in Keras 2 (and is higher on average) but it is stable within a range (max: 1808, min: 1117, reached after 18 iterations), which indicates that there's no leak. Are you able to run a Python profiler to see what's taking memory?
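A minimal sketch of one way to approach the profiler question, using the standard library's tracemalloc. It only sees allocations made through Python's own allocators, so buffers allocated natively by the TensorFlow runtime may not show up at all. The helper name `top_allocation_growth` and the `step` callable are hypothetical; in the snippet from this issue, `step` would be something like `lambda: model(images)`.

```python
import tracemalloc


def top_allocation_growth(step, runs=3, limit=5):
    """Run `step` several times and report where Python-level memory grew.

    `step` is a hypothetical zero-argument callable standing in for one
    forward pass, e.g. `lambda: model(images)`.
    """
    tracemalloc.start()
    try:
        step()  # warm-up, so one-time setup cost is not counted as growth
        before = tracemalloc.take_snapshot()
        for _ in range(runs):
            step()
        after = tracemalloc.take_snapshot()
    finally:
        tracemalloc.stop()
    # Differences are grouped by source line, largest absolute change first.
    return after.compare_to(before, "lineno")[:limit]


if __name__ == "__main__":
    # Toy workload that deliberately retains 1 MB per call.
    retained = []
    for stat in top_allocation_growth(lambda: retained.append(bytearray(1_000_000))):
        print(stat)
```

If the resident set size grows by gigabytes while this report shows only small differences, that would suggest the memory is held outside Python's allocators.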

For good measure, here's what I get when I do the same with tf_keras (Keras 2):

Memory usage after 0 run(s) (in MiB): 1041.750
Memory usage after 1 run(s) (in MiB): 1239.422
Memory usage after 2 run(s) (in MiB): 1075.500
Memory usage after 3 run(s) (in MiB): 1258.969
Memory usage after 4 run(s) (in MiB): 1270.234
Memory usage after 5 run(s) (in MiB): 1271.062
Memory usage after 6 run(s) (in MiB): 1281.203
Memory usage after 7 run(s) (in MiB): 1282.156
Memory usage after 8 run(s) (in MiB): 1284.469
Memory usage after 9 run(s) (in MiB): 1294.281
Memory usage after 10 run(s) (in MiB): 1297.281
Memory usage after 11 run(s) (in MiB): 1299.438
Memory usage after 12 run(s) (in MiB): 1300.125
Memory usage after 13 run(s) (in MiB): 1301.859
Memory usage after 14 run(s) (in MiB): 1305.547
Memory usage after 15 run(s) (in MiB): 1306.891
Memory usage after 16 run(s) (in MiB): 1306.984
Memory usage after 17 run(s) (in MiB): 1314.547
Memory usage after 18 run(s) (in MiB): 1314.875
Memory usage after 19 run(s) (in MiB): 1324.062

Although the variance and the average are both much lower, this one does look leaky in the sense that it's almost monotonically increasing.
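The distinction drawn above (bounded but noisy versus steadily rising) can be made a bit more concrete. The following is a rough sketch, not a leak detector: `summarize_rss` and its `warmup` parameter are hypothetical names, and the "fraction of rising steps" heuristic is an assumption chosen for illustration. The two lists are the readings posted in this comment.

```python
def summarize_rss(readings_mib, warmup=1):
    """Summarize a series of RSS readings (in MiB), one reading per run.

    The first `warmup` readings are ignored, since the first forward pass
    includes one-time allocations. Returns the min, max, spread, and the
    fraction of consecutive steps in which memory went up.
    """
    steady = list(readings_mib)[warmup:]
    if len(steady) < 2:
        raise ValueError("need at least two readings after warm-up")
    rises = sum(1 for a, b in zip(steady, steady[1:]) if b > a)
    return {
        "min": min(steady),
        "max": max(steady),
        "spread": max(steady) - min(steady),
        "rising_fraction": rises / (len(steady) - 1),
    }


if __name__ == "__main__":
    keras3 = [950.141, 1399.016, 1231.875, 1204.109, 1353.109, 1525.312,
              1807.875, 1594.766, 1609.703, 1556.141, 1720.438, 1606.094,
              1803.406, 1593.234, 1628.969, 1665.312, 1428.234, 1406.406,
              1117.484, 1380.219]
    keras2 = [1041.750, 1239.422, 1075.500, 1258.969, 1270.234, 1271.062,
              1281.203, 1282.156, 1284.469, 1294.281, 1297.281, 1299.438,
              1300.125, 1301.859, 1305.547, 1306.891, 1306.984, 1314.547,
              1314.875, 1324.062]
    print("Keras 3:", summarize_rss(keras3))
    print("Keras 2:", summarize_rss(keras2))
```

A rising fraction close to 1 over many runs hints at steady growth, while a large spread with a mixed rising fraction points to fluctuation within a range. Neither is conclusive over only 20 runs.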

Dobiasd commented 6 months ago

@sirfz Replacing model = tf.keras.applications.ResNet152V2() with model = tf.function(tf.keras.applications.ResNet152V2()) indeed works. The memory usage stays low in this minimal example. Thanks for this workaround! :+1:

Dobiasd commented 6 months ago

@fchollet Thanks for checking!

I tried to reproduce your test, but have failed so far, i.e., the memory usage is still high, even directly after gc.collect():

import gc

import numpy as np
import psutil
import tensorflow as tf

model = tf.keras.applications.ResNet152V2()
images = np.zeros([20, 224, 224, 3], dtype=np.uint8)
for run in range(10):
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f"Memory usage after {run} run(s) before gc.collect() (in MiB): {memory_usage_in_MiB:.3f}", flush=True)
    gc.collect()
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(f"Memory usage after {run} run(s) after gc.collect() (in MiB): {memory_usage_in_MiB:.3f}", flush=True)
    model(images)
Memory usage after 0 run(s) before gc.collect() (in MiB): 792.438
Memory usage after 0 run(s) after gc.collect() (in MiB): 792.438
Memory usage after 1 run(s) before gc.collect() (in MiB): 5983.020
Memory usage after 1 run(s) after gc.collect() (in MiB): 5983.020
Memory usage after 2 run(s) before gc.collect() (in MiB): 6978.793
Memory usage after 2 run(s) after gc.collect() (in MiB): 6978.793
Memory usage after 3 run(s) before gc.collect() (in MiB): 7011.441
Memory usage after 3 run(s) after gc.collect() (in MiB): 7011.441
Memory usage after 4 run(s) before gc.collect() (in MiB): 7213.758
Memory usage after 4 run(s) after gc.collect() (in MiB): 7213.758
Memory usage after 5 run(s) before gc.collect() (in MiB): 6951.520
Memory usage after 5 run(s) after gc.collect() (in MiB): 6951.520
Memory usage after 6 run(s) before gc.collect() (in MiB): 6536.066
Memory usage after 6 run(s) after gc.collect() (in MiB): 6536.066
Memory usage after 7 run(s) before gc.collect() (in MiB): 5985.203
Memory usage after 7 run(s) after gc.collect() (in MiB): 5985.203
Memory usage after 8 run(s) before gc.collect() (in MiB): 6931.805
Memory usage after 8 run(s) after gc.collect() (in MiB): 6931.805
Memory usage after 9 run(s) before gc.collect() (in MiB): 7641.566
Memory usage after 9 run(s) after gc.collect() (in MiB): 7641.566

(Dockerfile to reproduce)

> Are you able to run a Python profiler to see what's taking memory?

Sorry, currently no. Are you?

james77777778 commented 6 months ago

I suspect that this is more likely an issue on tensorflow or docker environment side.

import psutil

import keras
import keras.applications.resnet_v2

model = keras.applications.resnet_v2.ResNet152V2()

images = keras.ops.zeros([1, 224, 224, 3], dtype="uint8")

for run in range(100):
    model(images)
    memory_usage_in_MiB = psutil.Process().memory_info().rss / (1024 * 1024)
    print(
        f"Memory usage after {run} run(s) (in MiB): {memory_usage_in_MiB:.3f}",
        flush=True,
    )

I ran the above script for all backends, and here are the numbers:

| Backend | Memory Usage Range (in MiB) | No Growth after a Certain Number of Runs |
| --- | --- | --- |
| jax | 1210.980~1210.980 | V |
| numpy | 1635.582~1644.410 | V |
| tensorflow | 1589.508~1617.383 | X (leaked!) |
| torch | 867.094~867.344 | V |
| tensorflow (with tf.function) | 1645.301~1645.426 | V |

My environment:

Dobiasd commented 6 months ago

> CUDA 12, CUDNN 8.9

Oh, I ran my tests without any GPU. It's all CPU only. I've just expanded the issue title accordingly.

> I suspect that this is more likely an issue on tensorflow or docker environment side.

In the TensorFlow repo, I've been told to open the issue here. :grin:

Regarding Docker: The memory problem happens for me not only in Docker, but also when I run on bare metal.

fchollet commented 6 months ago

Thanks for the detailed analysis. The lack of the issue with other eager backends, and the disappearance of the issue when using a tf.function, strongly indicate that the leak may be at the level of the TF eager runtime. It is also likely system dependent, since I can't observe it on my system, nor on Colab (I tried both with TF 2.15 and TF 2.16 with the latest Keras, and while the memory usage differs across the 2 TF versions, there isn't a leak either way).

This isn't the first time we've seen memory leaks with the TF runtime (eager or graph).
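Since the behavior appears to differ between systems, it may help to attach the same environment summary to each report. This is a small sketch using only the standard library; the helper name `describe_environment` is hypothetical, and the distribution names in the default tuple are assumptions about what is installed (for example, a CPU-only install might use a different TensorFlow distribution name).

```python
import platform
import sys
from importlib import metadata


def describe_environment(packages=("tensorflow", "keras", "tf_keras", "numpy", "psutil")):
    """Collect interpreter, platform, and package versions for a bug report."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "machine": platform.machine(),
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages explicitly instead of failing.
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```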

github-actions[bot] commented 6 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

Dobiasd commented 6 months ago

> This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs.

If I'm not mistaken, the issue is not solved yet.

Or should we close it, because work continues in the corresponding issue in the TensorFlow repo?

cantonios commented 5 months ago

I added more comments on the associated TF bug, but mentioning this here:

I'm pretty sure it is a keras "issue" (though may not be a bug). We get similarly large memory usage in keras using the JAX backend (about 5GB), and it seems to be related to Trackable allocations used for training. We don't see the increased memory usage if only using model.predict(...).

These Trackable allocations are not present when using legacy keras and TF 2.16 - there memory usage remains about 1GB. So there's something about Keras 3 using a lot more memory during regular model(...) calls.
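One way to check for this kind of Python-side object growth, assuming the objects in question are tracked by the garbage collector, is to count live objects by type before and after a call. This is only a sketch: `live_object_counts` and `object_growth` are hypothetical helpers, and `step` stands in for a single `model(...)` or `model.predict(...)` call.

```python
import gc
from collections import Counter


def live_object_counts():
    """Count gc-tracked objects by type name, after a full collection."""
    gc.collect()
    return Counter(type(obj).__name__ for obj in gc.get_objects())


def object_growth(step, limit=10):
    """Return the types whose live-instance count grew most across one `step` call."""
    before = live_object_counts()
    step()
    after = live_object_counts()
    growth = after - before  # Counter subtraction keeps only positive deltas
    return growth.most_common(limit)
```

Comparing `object_growth(lambda: model(images))` with `object_growth(lambda: model.predict(images))` might show which object types accumulate in one path but not the other. Object counts say nothing about the size of native buffers those objects may own.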