
Graph size balloons over consecutive Keras models #2093

Open wchargin opened 5 years ago

wchargin commented 5 years ago

Running many Keras models in one Python script leads to steadily increasing event file sizes, even when the models and callbacks are unrelated to each other.

I noticed this because I ran 200 runs overnight, and wondered why the event file for the first one was 220KB while the event file for the last run was 20MB.

Run the following script in tf-nightly-2.0-preview in Python 3 (in a directory where you don’t mind the logs directory being erased):

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil

import numpy as np
import tensorflow as tf

LOGDIR = "logs"  # will be erased

INPUT_SHAPE = (10,)
OUTPUT_CLASSES = 10

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(INPUT_SHAPE),
      tf.keras.layers.Dense(64, activation="relu"),
      tf.keras.layers.Dense(OUTPUT_CLASSES, activation="softmax"),
  ])
  model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
  return model

def make_data(m):
  x = np.random.uniform(size=(m,) + INPUT_SHAPE)
  y = np.random.randint(0, OUTPUT_CLASSES, size=(m,))
  return (x, y)

def main():
  np.random.seed(0)
  (x_train, y_train) = make_data(1000)
  (x_test, y_test) = make_data(100)
  shutil.rmtree(LOGDIR, ignore_errors=True)
  for i in range(10):
    model = model_fn()
    callback = tf.keras.callbacks.TensorBoard(
        os.path.join(LOGDIR, str(i)),
        update_freq=600,
        profile_batch=0,
    )
    model.fit(x=x_train, y=y_train, callbacks=[callback])

if __name__ == "__main__":
  main()

Then run wc -c logs/**/train/*:

$ wc -c logs/**/train/*
  53584 logs/0/train/events.out.tfevents.1554484297.HOSTNAME.131202.562.v2
 106858 logs/1/train/events.out.tfevents.1554484297.HOSTNAME.131202.1350.v2
 160140 logs/2/train/events.out.tfevents.1554484298.HOSTNAME.131202.2131.v2
 213422 logs/3/train/events.out.tfevents.1554484298.HOSTNAME.131202.2912.v2
 266704 logs/4/train/events.out.tfevents.1554484298.HOSTNAME.131202.3693.v2
 320258 logs/5/train/events.out.tfevents.1554484299.HOSTNAME.131202.4474.v2
 373810 logs/6/train/events.out.tfevents.1554484299.HOSTNAME.131202.5255.v2
 427362 logs/7/train/events.out.tfevents.1554484300.HOSTNAME.131202.6036.v2
 480914 logs/8/train/events.out.tfevents.1554484300.HOSTNAME.131202.6817.v2
 534469 logs/9/train/events.out.tfevents.1554484300.HOSTNAME.131202.7598.v2
2937521 total

In each case, the graph makes up the vast majority of the event file (all but a kilobyte or so):

RUN     SIZE OF GRAPH
---     -------------
0         52367 bytes
1        105637 bytes
2        158919 bytes
3        212201 bytes
4        265483 bytes
5        319035 bytes
6        372587 bytes
7        426139 bytes
8        479691 bytes
9        533246 bytes
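
(For reference, one way these per-run graph sizes could be measured: a minimal sketch that sums the serialized graph events in each file via tf.compat.v1.train.summary_iterator. This is only an illustration, not necessarily how the table above was produced.)

import glob

import tensorflow as tf

# Sum the bytes of serialized GraphDef events in each run's event file.
for path in sorted(glob.glob("logs/*/train/events.out.tfevents.*")):
  graph_bytes = sum(
      len(event.graph_def)
      for event in tf.compat.v1.train.summary_iterator(path)
      if event.WhichOneof("what") == "graph_def"
  )
  print(graph_bytes, path)
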
wchargin commented 5 years ago

cc @stephanwlee

wchargin commented 5 years ago

I’ve triaged this as high priority because it makes logdirs grow quadratically in size. I have a logdir that should be 50MB but is 2GB.
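
(A rough sanity check on the quadratic growth, my arithmetic rather than part of the original report: if each run’s graph is about g bytes larger than the previous run’s, run i’s event file is roughly i·g bytes, so n runs total roughly g·n²/2 bytes. For the 200 overnight runs mentioned above, with files growing from about 220KB to about 20MB, that is roughly 200 × (0.22MB + 20MB) / 2 ≈ 2GB, matching the observed logdir size.)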

stephanwlee commented 5 years ago

A few things I noticed:

  1. This is reproducible in tf-nightly
  2. TF v1: Keras defines its ops on the default graph, so every model call keeps adding nodes onto the same graph.
  3. TF v2: Keras instantiates a FuncGraph('keras_graph'), but I lack knowledge of how this is actually used. Specifically, how are the nodes in this graph added when instantiating a Keras layer? In any case, this should have the same problem as (2), since it shares the same FuncGraph.
  4. Using TF v1, I added a line to reset the default graph (tf.compat.v1.reset_default_graph()) before calling the model function (sketched below, after the timing output), and got the following:
    run 0:      75408 bytes
    run 1:      75408 bytes
    run 2:      75408 bytes
    run 3:      75408 bytes
    run 4:      75408 bytes
    run 5:      75408 bytes
    run 6:      75408 bytes
    run 7:      75408 bytes
    run 8:      75408 bytes
    run 9:      75408 bytes

    Funnily enough, this affects the execution time too :)

    
    # BEFORE
    1000/1000 [==============================] - 0s 191us/sample - loss: 2.3172
    1000/1000 [==============================] - 0s 154us/sample - loss: 2.3278
    1000/1000 [==============================] - 0s 168us/sample - loss: 2.3309
    1000/1000 [==============================] - 0s 174us/sample - loss: 2.3373
    1000/1000 [==============================] - 0s 207us/sample - loss: 2.3350
    1000/1000 [==============================] - 0s 235us/sample - loss: 2.3242
    1000/1000 [==============================] - 0s 237us/sample - loss: 2.3371
    1000/1000 [==============================] - 0s 247us/sample - loss: 2.3092
    1000/1000 [==============================] - 0s 251us/sample - loss: 2.3562
    1000/1000 [==============================] - 0s 268us/sample - loss: 2.3335

    # AFTER
    1000/1000 [==============================] - 0s 207us/sample - loss: 2.3126
    1000/1000 [==============================] - 0s 141us/sample - loss: 2.3214
    1000/1000 [==============================] - 0s 138us/sample - loss: 2.3141
    1000/1000 [==============================] - 0s 154us/sample - loss: 2.3143
    1000/1000 [==============================] - 0s 140us/sample - loss: 2.3242
    1000/1000 [==============================] - 0s 147us/sample - loss: 2.3290
    1000/1000 [==============================] - 0s 144us/sample - loss: 2.3246
    1000/1000 [==============================] - 0s 134us/sample - loss: 2.3250
    1000/1000 [==============================] - 0s 148us/sample - loss: 2.3380
    1000/1000 [==============================] - 0s 137us/sample - loss: 2.3268


WARNING: this is not a replacement for a real benchmark, but I think I see the trend :)
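
For concreteness, a minimal sketch of the change described in (4): the training loop from the repro script with the default graph reset before each model is built (reusing LOGDIR, model_fn, x_train, and y_train from that script; assumes TF v1 graph mode):

for i in range(10):
  # Drop the nodes accumulated on the default graph by previous
  # iterations before building the next model.
  tf.compat.v1.reset_default_graph()
  model = model_fn()
  callback = tf.keras.callbacks.TensorBoard(
      os.path.join(LOGDIR, str(i)),
      update_freq=600,
      profile_batch=0,
  )
  model.fit(x=x_train, y=y_train, callbacks=[callback])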

@omalleyt12 is my assessment correct? Also, can you shed some light on (3)? Thanks!

manivaradarajan commented 5 years ago

Taylor / Tom, can you take a look and update?

robieta commented 5 years ago

You are correct; all Keras models share a single graph (unless called under an explicit graph scope, in which case they will use that instead). The primary reason for this is that you can mix and match models, so the ops have to live on the same graph. However, this does indeed introduce issues as you accrue more and more orphaned nodes. For now, the solution is to call tf.keras.backend.clear_session(). (Despite the name, it works in both 1.x and 2.0.) This just nukes everything and starts over.
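
For example, a minimal sketch of that workaround applied to the repro script's loop (the placement at the top of each iteration is my assumption; LOGDIR, model_fn, x_train, and y_train come from the script above):

for i in range(10):
  # Clear Keras's global graph/session state so each model starts from a
  # fresh graph instead of accumulating onto the shared one.
  tf.keras.backend.clear_session()
  model = model_fn()
  callback = tf.keras.callbacks.TensorBoard(
      os.path.join(LOGDIR, str(i)),
      update_freq=600,
      profile_batch=0,
  )
  model.fit(x=x_train, y=y_train, callbacks=[callback])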

I am currently working on a prototype to break up the Keras global graph. One of the key motivating factors is our current inability to garbage-collect without clearing everything. I don't have an estimate for the timeline, but I'll keep you apprised.