Open · alessiomora opened this issue 1 year ago
@sushreebarsa Yes, I confirm. As I highlighted in the issue above, the problem is not present with TF v2.9.2, so the issue is not reproducible with TF v2.9.2. Thank you very much for your help.
@sampathweb does this appear to have similar symptoms compared to the previous memory leak issues we've seen?
This may be similar to a recent memory leak in evaluation. Just a quick check: if you run 10,000 epochs instead of looping over Model.fit 10,000 times, do you still see the memory leak?
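For clarity, these are the two variants being contrasted; this is only a rough sketch with a hypothetical toy model and data, not the reporter's code:

import numpy as np
import tensorflow as tf

# Hypothetical toy model and data, just to make the two variants concrete
x = np.random.standard_normal((64, 1000)).astype("float32")
y = np.ones((64,), dtype="float32")
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1000,))])
model.compile(optimizer="sgd", loss="mse")

# Variant A: a single fit call running 10,000 epochs internally
model.fit(x, y, epochs=10_000, verbose=0)

# Variant B: 10,000 separate one-epoch fit calls (the pattern reported to leak)
for _ in range(10_000):
    model.fit(x, y, epochs=1, verbose=0)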
Hey, I'm currently dealing with the same problem. The memory leak is caused by tf.data.Dataset.
Hi all, thank you for your help. Is there a solution to this behaviour?
Unfortunately, I have not found a solution yet. In my case, I use a workaround based on a batch script, because when the Python program terminates all memory is released. So instead of using the for-loop in Python, you can write a for-loop in a batch script that calls the script containing your fit method. (You just need to find out the maximum number of iterations before the leak crashes your program.)
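The same relaunch idea can also be driven from Python with subprocess instead of a batch script; here is a rough sketch, where train_chunk.py is a hypothetical script that loads the latest checkpoint, runs a bounded number of fit iterations, saves, and exits. The point is that each chunk runs in its own process, so the OS reclaims all of its memory on exit:

import subprocess
import sys

# Hypothetical driver: relaunch the training script repeatedly so that every
# chunk runs in a fresh process and its memory is fully released when it exits.
NUM_CHUNKS = 100  # chosen so that a single chunk stays below the leak threshold

for chunk in range(NUM_CHUNKS):
    result = subprocess.run([sys.executable, "train_chunk.py", "--chunk", str(chunk)])
    if result.returncode != 0:
        sys.exit(f"Chunk {chunk} failed with return code {result.returncode}")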
Hi @sushreebarsa, thank you for your help. Any news on the issue? I do believe that the memory leak is caused by model.fit().
Thanks.
Hey @alessiomora, I had a similar issue. I wanted to call model.fit() inside the loop because my dataset was too large. What worked for me was to do a little bit of cleanup with Python del and gc.collect(), combined with tf.keras.backend.clear_session().
I am using TensorFlow 2.12 on Windows under WSL2, as recommended in: https://www.tensorflow.org/install/pip
My code:
import gc
import tensorflow as tf

# Load your model just once before the loop
for i in range(1, 1001):
    # Get part of your dataset
    input_set, output_set = DataSet.get_training_data()

    # Train your model with custom batch_size, num_epochs
    history = model.fit(input_set, output_set, validation_data=(input_set, output_set),
                        epochs=num_epochs, batch_size=batch_size)

    # Clean up memory after use
    del history
    del input_set
    del output_set
    tf.keras.backend.clear_session()
    gc.collect()  # run the garbage collector

    # Save once in a while
    if i % 100 == 0:
        model.save(f"./checkpoints/AI_checkpoint_{i//60}.h5")
Before I added the memory-cleanup section, the program used to run two or three iterations and then crash because memory was full (sometimes system RAM, sometimes GPU memory). This worked for me; I hope it works for you too.
Hi @Metcoler, thank you for the suggestion. However, the memory still seems to steadily increase, and eventually an OOM error appears. I am sure there is a problem in the .fit() implementation, as reported in other GitHub issues and Stack Overflow questions.
This may be related to:
I'm facing similar issues with the GCP pre-built container image with TF 2.12 GPU (europe-docker.pkg.dev/vertex-ai/training/tf-gpu.2-12.py310:latest), on a system with 32 vCPUs, 208 GB of RAM, and 4 NVIDIA Tesla V100s. This is the chart of RAM usage:
The spikes are the moments in which validation is performed. My data pipeline consists of loading multiple TFRecords with images and labels; record sizes range from 300 MB to at most 1.8 GB.
Unfortunately I cannot disclose my full code, but this is the order of my tf.data.Dataset operations:
AUTOTUNE = tf.data.AUTOTUNE

dataset = tf.data.Dataset.list_files(tfrecord_files, shuffle=True)
# cycle_length here is 100 for the training dataset, None for validation;
# parser.get_parsed_examples maps each TFRecord filename to a dataset of parsed examples
dataset = dataset.interleave(parser.get_parsed_examples, cycle_length=cycle_length, num_parallel_calls=AUTOTUNE)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
# transformation function that normalizes the data
dataset = dataset.map(transform_func, num_parallel_calls=AUTOTUNE)
dataset = dataset.prefetch(AUTOTUNE)
Then on model.fit (which I call only once, not in a loop), you can see on the RAM chart that the training intervals are memory-efficient (the dataset is huge, yet memory stays constant and drops at the end of consumption). However, there are validation spikes that increase exponentially in memory. Any hints or ideas about what might be happening?
I already tried cleaning up memory with gc.collect() after the end of validation, using a callback's on_test_end.
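For reference, a minimal sketch of that kind of cleanup callback (using the standard tf.keras.callbacks.Callback API; whether this actually reclaims the spiked memory is exactly what is in question here):

import gc
import tensorflow as tf

class ValidationCleanup(tf.keras.callbacks.Callback):
    """Hypothetical callback that forces garbage collection after each evaluation pass."""

    def on_test_end(self, logs=None):
        # Called at the end of evaluation/validation; try to release Python-level garbage.
        gc.collect()

# Usage: model.fit(train_ds, validation_data=val_ds, callbacks=[ValidationCleanup()])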
thank you for your help!
I had switched from TensorFlow.js to the Python version; this issue made me switch back to JavaScript 😅
Is there an update on this? The problem still appears in TF 2.14 and has been reported many times.
Any update on this?
any updates?
tf.data.Dataset does not seem to be the source of the leak; this code has no issues:
import tensorflow as tf

d = 128 * 4
for r in range(0, 10000):
    # Build (and discard) a fresh dataset on every iteration
    ds = tf.data.Dataset.from_tensor_slices((tf.random.uniform((d, 1000)), tf.ones((d,))))
    ds = ds.batch(64)
Rather, it seems that the way the OP wrote the code, it keeps regenerating new weights without garbage-collecting them. The memory increases by a multiple of the number of weights at each epoch.
As another proof, one can train using:
import numpy as np

x = np.random.standard_normal((64, 1000))
y = np.ones((64,))
and get the same result.
The same happens with Sequential; the number of weights and everything else is as expected, and in both cases there is a memory leak.
Apparently, it's a TensorFlow issue (see the link just below this comment).
System information.
Describe the problem. Memory usage steadily increases when using tf.keras.Model and Model.fit() in a loop, and eventually leads to an Out Of Memory exception by saturating memory. clear_session() does not help. The same code with TF version 2.9.2 has almost constant memory usage instead, and works as expected.
Describe the current behavior. Memory usage steadily increases when using tf.keras.Model and Model.fit() in a loop, and eventually leads to an Out Of Memory exception by saturating memory.
Describe the expected behavior. The memory usage remains almost constant.
Standalone code to reproduce the issue.
Provide a reproducible test case that is the bare minimum necessary to generate the problem. If possible, please share a link to Colab/Jupyter/any notebook.
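The reporter's original standalone code is not included in this excerpt. As a hypothetical minimal sketch only, based on the pattern described in this thread (rebuilding a model and calling fit() inside a loop, with cleanup that does not help), something like the following illustrates the kind of loop being discussed:

import gc
import numpy as np
import tensorflow as tf

x = np.random.standard_normal((64, 1000)).astype("float32")
y = np.ones((64,), dtype="float32")

for i in range(10_000):
    # A fresh model is built on every iteration, as hypothesized above
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation="relu", input_shape=(1000,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    model.fit(x, y, epochs=1, verbose=0)

    # Cleanup that, per the reports above, does not prevent memory growth on affected TF versions
    tf.keras.backend.clear_session()
    del model
    gc.collect()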
Source code / logs.
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem.