araffin / learning-to-drive-in-5-minutes

Implementation of reinforcement learning approach to make a car learn to drive smoothly in minutes
https://towardsdatascience.com/learning-to-drive-smoothly-in-minutes-450a7cdb35f4
MIT License
284 stars, 88 forks

OOM error when training VAE #38

Open kncrane opened 2 years ago

kncrane commented 2 years ago

Describe the bug

I'm running python -m vae.train --n-epochs 50 --verbose 0 --z-size 64 -f logs/images_generated_road_single_colour/ and, after a number of training iterations, getting the error: Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.17GiB (rounded to 1251090432).

I've added checkpoint saving using the save_checkpoint() function in vae/model.py. The training run was crashing after a number of iterations when it tried to create the .meta file for the checkpoint, but I got around that by passing write_meta_graph=False to the saver.save function.
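
For reference, a minimal sketch of saving a checkpoint without the .meta file. It assumes TensorFlow >= 1.14 (imported through tensorflow.compat.v1); the variable name, checkpoint prefix and temporary directory are made up for illustration and are not taken from vae/model.py.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF 1.x style API (assumption: TensorFlow >= 1.14)

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variable standing in for the VAE weights.
    weights = tf.get_variable("weights", shape=(2, 2), initializer=tf.zeros_initializer())
    # The Saver adds its own ops, so it must be created before the graph is finalized.
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
graph.finalize()

ckpt_dir = tempfile.mkdtemp()
with tf.Session(graph=graph) as sess:
    sess.run(init)
    # write_meta_graph=False skips exporting the graph definition (.meta file).
    saver.save(sess, os.path.join(ckpt_dir, "vae"), global_step=1, write_meta_graph=False)

saved_files = sorted(os.listdir(ckpt_dir))
print(saved_files)
```

Skipping the .meta export only avoids serialising the graph; it does not stop the graph itself from growing.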

When I was still getting an OOM error, I added self.sess.graph.finalize() to the _init_session() function in vae/model.py to make the graph read-only and catch any changes to it. An exception was raised from the line vae_controller.set_target_params() in vae/train.py, which in turn calls assign_ops.append(param.assign(loaded_p)) from within set_params() in vae/model.py.

I was reading this article https://riptutorial.com/tensorflow/example/13426/use-graph-finalize---to-catch-nodes-being-added-to-the-graph and the memory leak I am getting sounds most like their third example: "subtle (e.g. a call to an overloaded operator on a tf.Tensor and a NumPy array, which implicitly calls tf.convert_to_tensor() and adds a new tf.constant() to the graph)."
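
A small sketch of that kind of implicit graph growth, assuming TensorFlow >= 1.14 through tensorflow.compat.v1. The variable v is hypothetical; the point is only that every call to assign() with a NumPy array adds new ops to the graph, even if the returned op is never run.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x style API (assumption: TensorFlow >= 1.14)

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variable; any tf.Variable behaves the same way here.
    v = tf.get_variable("v", shape=(4,), initializer=tf.zeros_initializer())

    op_counts = []
    for _ in range(3):
        # The NumPy array is converted to a new constant, and a new assign op is created.
        v.assign(np.ones(4, dtype=np.float32))
        op_counts.append(len(graph.get_operations()))

# The number of ops in the graph keeps increasing with every call.
print(op_counts)
```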

Did you run into any OOM errors from graph growth when you were running these scripts or do you have any insights? Cheers Antonin

Code example

This is my training loop section from vae/train.py (validation has been added). The last line is the problem line:

for epoch in range(args.n_epochs):
    print("Training ...")
    pbar = tqdm(total=len(train_minibatchlist))
    for obs in train_data_loader:
        feed = {vae.input_tensor: obs}
        (train_loss, r_loss, kl_loss, train_step, _) = vae.sess.run([
            vae.loss,
            vae.r_loss,
            vae.kl_loss,
            vae.global_step,
            vae.train_op
        ], feed)
        pbar.update(1)
    pbar.close()

    print("Evaluating ...")
    pbar = tqdm(total=len(val_minibatchlist))
    for obs in val_data_loader:
        feed = {vae.input_tensor: obs}
        (val_loss, val_r_loss, val_kl_loss) = vae.sess.run([
            vae.loss,
            vae.r_loss,
            vae.kl_loss
        ], feed)
        pbar.update(1)
    pbar.close()
    print("Epoch {:3}/{}".format(epoch + 1, args.n_epochs))
    print("Optimization Step: ", (train_step + 1), ", Loss: ", train_loss, " Validation Loss: ", val_loss)

    # Update params
    vae_controller.set_target_params()

This is the edited _init_session() from vae/model.py:

    def _init_session(self):
        """Launch tensorflow session and initialize variables"""
        self.sess = tf.Session(graph=self.graph)
        self.sess.run(self.init)
        self.sess.graph.finalize()
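
To illustrate what finalize() does in isolation, here is a sketch under the same assumption (TensorFlow >= 1.14 via tensorflow.compat.v1) with a hypothetical variable v: once the graph is finalized, any attempt to add an op raises a RuntimeError instead of silently growing the graph.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x style API (assumption: TensorFlow >= 1.14)

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical variable used only for this demonstration.
    v = tf.get_variable("v", shape=(4,), initializer=tf.zeros_initializer())
graph.finalize()  # the graph is now read-only

raised = False
message = ""
try:
    with graph.as_default():
        # This would add a constant and an assign op, which a finalized graph forbids.
        v.assign(np.ones(4, dtype=np.float32))
except RuntimeError as err:
    raised = True
    message = str(err)

print(raised, message)
```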

And this is the supposed source of the memory leak, within vae/model.py:

    def set_params(self, params):
        assign_ops = []
        for param, loaded_p in zip(self.params, params):
            assign_ops.append(param.assign(loaded_p))
        self.sess.run(assign_ops)
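
One common way to avoid creating new ops on every call is to build placeholders and assign ops once and feed the values at run time. The following is a sketch of that pattern, not the repository's code: ParamSetter and the variables w and b are hypothetical names, and it assumes TensorFlow >= 1.14 through tensorflow.compat.v1.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x style API (assumption: TensorFlow >= 1.14)

tf.disable_v2_behavior()


class ParamSetter:
    """Hypothetical helper: creates placeholders and assign ops a single time."""

    def __init__(self, variables):
        # One placeholder per variable, with matching dtype and shape.
        self.placeholders = [
            tf.placeholder(v.dtype.base_dtype, shape=v.get_shape()) for v in variables
        ]
        self.assign_ops = [v.assign(ph) for v, ph in zip(variables, self.placeholders)]

    def set_params(self, sess, values):
        # Only feeds values into existing ops; nothing is added to the graph.
        sess.run(self.assign_ops, feed_dict=dict(zip(self.placeholders, values)))


graph = tf.Graph()
with graph.as_default():
    w = tf.get_variable("w", shape=(2, 3), initializer=tf.zeros_initializer())
    b = tf.get_variable("b", shape=(3,), initializer=tf.zeros_initializer())
    setter = ParamSetter([w, b])
    init = tf.global_variables_initializer()

sess = tf.Session(graph=graph)
sess.run(init)
graph.finalize()  # any later graph modification would now raise

n_ops_before = len(graph.get_operations())
for step in range(3):
    new_values = [
        np.full((2, 3), step, dtype=np.float32),
        np.full((3,), step, dtype=np.float32),
    ]
    setter.set_params(sess, new_values)
n_ops_after = len(graph.get_operations())

final_w, final_b = sess.run([w, b])
sess.close()
```

In vae/model.py the equivalent change would be to create the placeholders and assign ops when the graph is built, before _init_session() finalizes it, and have set_params() only call sess.run with a feed dict.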

Let me know if you want full scripts

System Info

Describe the characteristic of your environment:

Additional context

This is the full error message:

  File "/home/b3024896/dolphinsstorage/dolphin-tracking/DonkeyTrack/vae/train.py", line 157, in <module>
    vae_controller.set_target_params()
  File "/home/b3024896/dolphinsstorage/dolphin-tracking/DonkeyTrack/vae/controller.py", line 105, in set_target_params
    self.target_vae.set_params(params)
  File "/home/b3024896/dolphinsstorage/dolphin-tracking/DonkeyTrack/vae/model.py", line 208, in set_params
    assign_ops.append(param.assign(loaded_p))
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 1952, in assign
    name=name)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 227, in assign
    validate_shape=validate_shape)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/ops/gen_state_ops.py", line 66, in assign
    use_locking=use_locking, name=name)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 527, in _apply_op_helper
    preferred_dtype=default_dtype)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1224, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 305, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 246, in constant
    allow_broadcast=True)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 290, in _constant_impl
    name=name).outputs[0]
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3588, in create_op
    self._check_not_finalized()
  File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3225, in _check_not_finalized
    raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.

araffin commented 2 years ago

Hello,

Did you run into any OOM errors from graph growth when you were running these scripts or do you have any insights?

I've never experienced any OOM (I was using Google Colab to train the VAE) and I haven't used that code for a while now (I've switched to PyTorch).

The only thing I can provide you is the code I'm now using to train the AE (just made it public, I should open source the rest too): https://github.com/araffin/aae-train-donkeycar

kncrane commented 2 years ago

I've been intending to switch over to PyTorch for months. OK, thank you, will check it out!