cybertronai / gradient-checkpointing

Make huge neural nets fit in memory
MIT License
2.69k stars 270 forks source link

Does this work with `fit_generator`? #41

Open ghost opened 5 years ago

ghost commented 5 years ago

@yaroslavvb Does this work with Keras fit_generator? I saw you used fit, but do you know if it will work with fit_generator?

yaroslavvb commented 5 years ago

No idea, haven't really used Keras

ghost commented 5 years ago

@yaroslavvb Just so I know, is a neural net supposed to train slower when using your memory_saving_gradients, or does it train faster?

yaroslavvb commented 5 years ago

I wish it could train faster and use less memory. Alas, you trade compute for memory. There are concrete numbers in the README.md

ghost commented 5 years ago

@yaroslavvb I've tried downgrading from tf-1.8 to 1.5 and still can't get it to work. I'm on Windows 10 and my task manager doesn't show any less memory being utilized when I use memory_saving_gradients.

Right now, I am on TensorFlow 1.5 with Keras 2.1.6, using 64-bit Python 3.5. I make sure to use the TensorFlow implementation of the Keras backend (from tensorflow.python.keras._impl.keras import backend as K) as well as the TensorFlow Keras modules for the Keras layers.

I define my model, add gradient checkpointing for several convolutional and fully-connected layers, then compile the model in a function called get_model.

Here is all my code. I haven't put down a bunch of my pandas functions for dataset manipulations, but if for some reason you think they'd be important let me know and I'll post them here. Here is the meat of my code. I don't feel like I'm doing anything too out of the ordinary, but still can't get your code to work.

This model still takes up the same memory as it had before and takes just as long per epoch, with or without memory_saving_gradients as I've written it. Can you take a look?

# -*- coding: utf-8 -*-

##########
#LIBRARIES
##########

#Future
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pandas as pd

pd.set_option('chained_assignment',None) #Sets `SettingWithCopyWarning` to None. If
                                         # making a chained assignment, the outcome may
                                         # vary depending on if the data is a view of
                                         # other data or a copy of other data.

import cv2

import os
import time
import argparse
import h5py
import gc

import multiprocessing as mp

import tensorflow as tf
from tensorflow.python.keras._impl.keras import backend as K

from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.batching import map_and_batch

import memory_saving_gradients

Dataset = tf.data.Dataset

from tensorflow.python.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.python.keras.models import Sequential, Model, load_model, model_from_yaml
from tensorflow.python.keras.callbacks import LearningRateScheduler, ModelCheckpoint, EarlyStopping, History, TensorBoard
from tensorflow.python.keras import regularizers, optimizers
from tensorflow.python.keras.layers import Conv2D, Dense, Flatten, Dropout, Input, Lambda, Activation

##################
#GLOBAL VARIABLES
##################

# Shape of the images as stored on disk: (channels, height, width).
img_shape_raw = (3, 160, 320)

batch_size = 32

num_epochs = 1

# Number of pixel rows removed from the top / bottom of every raw image.
crop_top = 70
crop_btm = 25

# The whole script works with channels-first tensors; tell the Keras backend.
img_format = 'channels_first'
K.set_image_data_format(img_format)

img_shape_input = (img_shape_raw[0],
                   img_shape_raw[1] - crop_top - crop_btm,
                   img_shape_raw[2]) #(3, 65, 320)

# Leave one core free, but never drop to zero on a single-core machine.
max_procs = mp.cpu_count() - 1 or 1 # author's machine: 4 physical cores, 8 logical cores
max_q_size = batch_size

root = r'.'

# Path components are joined one by one (instead of embedding a '\' in the
# string) so the same folders are addressed on Windows and on POSIX systems.
fldr_img_raw = os.path.join( root, 'dat', 'raw' )
fldr_csv_raw = os.path.join( root, 'dat', 'raw' )

fldr_img_mod = os.path.join( root, 'dat', 'mod' )
fldr_csv_mod = os.path.join( root, 'dat', 'mod' )

train_csv = os.path.join(fldr_csv_mod, 'training_data.csv')
val_csv = os.path.join(fldr_csv_mod, 'validation_data.csv')
test_csv = os.path.join(fldr_csv_mod, 'test_data.csv')

pth_bins_fl = os.path.join( fldr_csv_mod, 'bins.txt' )

fldr_fig = os.path.join( root, r'fig' )

# Hyper-parameter sweep: one training run per entry of `lr`. `run` and
# `hparam_str` are only ever read at index 0 when building file names.
lr = [1e-4, ]
run = [1, ]

hparam_str = ['1e-4', ]

fldr_log = os.path.join( root, r'log', hparam_str[0], 'run_{:04d}'.format(run[0]))

fldr_arch = os.path.join( root, r'arch' )
fldr_wt = os.path.join( root, r'wt' )
fldr_ckpt = os.path.join( root, r'ckpt' )
fldr_mdl = os.path.join( root, r'mdl' )

fldr_summary = os.path.join( root, r'summary' )

# Only the run number is filled in here; the `{epoch}` and
# `{val_mean_squared_error}` fields are left in the string as a template.
fl_fmt_wt_ckpt = os.path.join( fldr_ckpt,
                               r'wt_ckpt-run_{run:04d}'.format(run=run[0]) + '_epoch_{epoch:04d}_val_mse_{val_mean_squared_error:.7f}.h5' )

################
#DATA GENERATOR
################

def get_data( keep_ptl = 75 ):
    '''Placeholder for the data-loading routine (body omitted from this listing).

       The real implementation is described as returning the train,
       validation and test dataframes while keeping only a given percentile
       (`keep_ptl`) of the original data. It was left out for space.

       As written here the function does nothing and returns None, whereas
       `main` unpacks four values from the result.
    '''
    return None

def generator_from_df( df, batch_size, shuffle = True ):
    '''Python generator that yields batches evaluated from a tf.data pipeline.

       The pipeline is built from the `Image_Path` and `Angle` columns of
       `df`: each file is read, decoded as a 3-channel image and transposed
       to channels-first, then the stream is repeated `num_epochs` times
       (shuffled first when `shuffle` is true), batched and prefetched.

       Because this is a generator function, nothing below runs until the
       first item is requested; the graph ops are therefore created in
       whatever graph / Keras session is current at that moment.

       Each yielded item is the result of `session.run` on the iterator's
       next element, i.e. an (images, angles) pair. Iteration stops when the
       dataset is exhausted.
    '''

    def load_example( path, angle ):
        # One (image, angle) pair wrapped in a single-element Dataset, as
        # required by parallel_interleave.
        raw_bytes = tf.read_file( path )
        decoded = tf.image.decode_image( raw_bytes, channels=3 )
        channels_first = tf.transpose( decoded, [2, 0, 1] )
        return Dataset.from_tensors( (channels_first, angle) )

    path_tensor = tf.convert_to_tensor( df['Image_Path'].values )
    angle_tensor = tf.convert_to_tensor( df['Angle'].values )

    pipeline = Dataset.from_tensor_slices( (path_tensor, angle_tensor) )

    interleave = tf.contrib.data.parallel_interleave( load_example,
                                                      cycle_length = batch_size,
                                                      sloppy = True )
    pipeline = pipeline.apply( interleave )

    if not shuffle:
        pipeline = pipeline.repeat( num_epochs )
    else:
        pipeline = pipeline.apply( shuffle_and_repeat( buffer_size = 2*batch_size,
                                                       count = num_epochs ) )

    # The mapped function is the identity; map_and_batch is used for batching.
    batcher = map_and_batch( lambda img_pth, ang: (img_pth,ang),
                             batch_size,
                             num_parallel_batches = max_procs )
    pipeline = pipeline.apply( batcher )
    pipeline = pipeline.prefetch( max_procs )

    one_shot = pipeline.make_one_shot_iterator()
    session = K.get_session()
    batch_op = one_shot.get_next()

    try:
        while True:
            yield session.run( batch_op )
    except tf.errors.OutOfRangeError:
        # The dataset has been consumed `num_epochs` times: end the generator.
        return

###########
#GET MODEL
###########

def get_model( lr, is_training = True ):
    '''Build and compile the steering-angle regression network.

       Architecture: input -> normalisation Lambda -> five Conv2D layers ->
       Flatten -> four Dense layers -> single-unit Dense output. All hidden
       layers share one L2(0.001) kernel/bias regulariser.

       The outputs of `conv02`, `conv04`, `fc01` and `fc03` are added to the
       graph collection named 'checkpoints', and the backend's `gradients`
       function is replaced by `memory_saving_gradients.gradients_collection`
       before the model is compiled.

       Arguments:
           lr:          learning rate passed to the Adam optimizer.
           is_training: optional flag accepted so callers may pass
                        `is_training=...`; it is currently ignored because
                        nothing in this function depends on it.

       Returns:
           The compiled Keras `Model` (MSE loss, Adam optimizer, 'mse' metric).

       NOTE(review): no `activation` argument is given to any layer, so the
       layers use the Keras default - confirm this is intended.
    '''

    l2 = regularizers.l2(0.001)

    # (name scope, layer name, number of filters, kernel size)
    conv_specs = [ ('ConvLayer_01', 'conv01',  4, (5,5)),
                   ('ConvLayer_02', 'conv02', 12, (5,5)),
                   ('ConvLayer_03', 'conv03', 24, (5,5)),
                   ('ConvLayer_04', 'conv04', 24, (3,3)),
                   ('ConvLayer_05', 'conv05', 32, (3,3)) ]

    # (name scope, layer name, number of units)
    dense_specs = [ ('FullyConnectedLayer_01', 'fc01', 100),
                    ('FullyConnectedLayer_02', 'fc02',  50),
                    ('FullyConnectedLayer_03', 'fc03',  25),
                    ('FullyConnectedLayer_04', 'fc04',  10) ]

    with tf.name_scope('Input'):
        inputs = Input( shape=img_shape_input, name='input' )

        # Scale pixel values from [0, 255] to [-0.5, 0.5].
        x = Lambda(lambda x: x / 255. - 0.5,
                   input_shape=img_shape_input, name = 'norm_-0.5_to_0.5')(inputs)

    with tf.name_scope('Hidden_Layers'):

        for scope, layer_name, filters, kernel in conv_specs:
            with tf.name_scope(scope):
                x = Conv2D(filters, kernel,
                           kernel_regularizer=l2,
                           bias_regularizer=l2,
                           padding='same',
                           name=layer_name)(x)

        with tf.name_scope('Flatten'):
            x = Flatten(name='flatten')(x)

        for scope, layer_name, units in dense_specs:
            with tf.name_scope(scope):
                x = Dense(units,
                          kernel_regularizer=l2,
                          bias_regularizer=l2,
                          name=layer_name)(x)

    with tf.name_scope('Output'):
        outputs = Dense(1,
                        name='output')(x)

    # Create Model

    model = Model( inputs = inputs, outputs = outputs )

    # Learning rate and decay are handed straight to the optimizer.
    adam = optimizers.Adam( lr = lr, decay = 0.001 )

    # Memory Saving Gradients: register the chosen layer outputs as
    # checkpoints, then swap the backend's gradient function.

    for layer_name in ( 'conv02', 'conv04', 'fc01', 'fc03' ):
        tf.add_to_collection('checkpoints',
                             model.get_layer(layer_name).get_output_at(0))

    K.__dict__['gradients'] = memory_saving_gradients.gradients_collection

    # Compile Model

    model.compile(loss='mean_squared_error', optimizer=adam, metrics=['mse'])

    return model

class CumulativeHistory( History ):
    '''
    A `History` callback that can be reused across several training runs.

    The parent's `on_train_begin` is only invoked the first time training
    starts, i.e. while the instance has no `epoch` attribute yet; on later
    runs it is skipped so that whatever has been recorded so far is kept.
    '''
    def on_train_begin( self, logs=None ):
        already_started = hasattr(self, 'epoch')
        if already_started:
            # Resuming: keep the existing record untouched.
            return
        super(CumulativeHistory, self).on_train_begin( logs )

def _timestamp():
    """Return the current local time as `YYYY-MM-DD_HH-MM-SS` for file names."""
    # Imported locally so this helper does not depend on the module's
    # top-level import list.
    import datetime

    return datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')


def _save_snapshot( model, stage, ts ):
    """Save `model` (full model and weights only) tagged with `stage` and `ts`."""
    suffix = '_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.h5'

    model.save( os.path.join( fldr_mdl, 'model_' + stage + suffix ) )
    model.save_weights( os.path.join( fldr_wt, 'weights_' + stage + suffix ) )


def main(*args, **kargs):
    """ Behavioral Cloning Project

    Parse the command line, fetch the data frames and then, for every entry
    of the module-level `lr` list: build a new model (or load the one given
    with `--checkpoint`), write its summary / architecture / initial
    snapshot, train it with `fit_generator` and save a final snapshot.

    The positional and keyword arguments are accepted but not used; options
    come from the command line:

      -c / --checkpoint   `.h5` model file to resume from
      -e / --epoch        initial epoch when resuming (0 when omitted)
    """

    parser = argparse.ArgumentParser(description='Behavioral Cloning Project')

    parser.add_argument('-c', '--checkpoint', type=str, help='Checkpoint (`.h5` file)')
    parser.add_argument('-e', '--epoch', type=int, help='Initial epoch')

    # Named `cli_args` so the `*args` parameter is not shadowed.
    cli_args = parser.parse_args()

    resume = cli_args.checkpoint is not None

    # `--epoch` is optional: passing None on to fit_generator would break the
    # epoch loop, so fall back to 0 when it is not given.
    initial_epoch = 0
    if resume and cli_args.epoch is not None:
        initial_epoch = cli_args.epoch

    # Set Configuration

    config = tf.ConfigProto( intra_op_parallelism_threads = max_procs,
                             inter_op_parallelism_threads = 0) # set automatically to number of logical cores

    config.gpu_options.allow_growth = True

    # Get Data

    df_train, df_val, df_test, bins = get_data( keep_ptl = 60 )

    nbatches_train = df_train.shape[0] // batch_size
    nbatches_val   = df_val.shape[0] // batch_size

    # Callbacks shared by all runs

    history = CumulativeHistory()

    early_stop = EarlyStopping( monitor='val_mean_squared_error',
                                min_delta=1e-4,
                                patience=50,
                                verbose=0,
                                mode='min')

    model_ckpt = ModelCheckpoint( fl_fmt_wt_ckpt,
                                  monitor='val_mean_squared_error',
                                  verbose=0,
                                  save_best_only=True,
                                  save_weights_only=True,
                                  period=1)

    base_callbacks = [history, early_stop, model_ckpt]

    # Training

    for current_lr in lr:

        # Everything tied to one run is created inside the loop: the data
        # generators wrap one-shot iterators that are exhausted after
        # `num_epochs`, and the graph/session pair must match the model.
        train_graph = tf.Graph()
        train_sess = tf.Session( config = config, graph = train_graph )
        K.set_session( train_sess )

        train_generator = generator_from_df( df_train, batch_size )
        val_generator   = generator_from_df( df_val,   batch_size, shuffle=False )

        with train_graph.as_default(), train_sess.as_default():

            if resume:

                # Load inside the training graph/session so the restored
                # variables live where `fit_generator` will run them.
                train_model = load_model( cli_args.checkpoint )

            else:

                # Print model summary
                summary_fl_pth = os.path.join( fldr_summary, 'model_summary_run_{:04d}_'.format(run[0]) + r'.txt' )

                train_model = get_model( current_lr )

                with open(summary_fl_pth, 'w') as summary_file:
                    train_model.summary( print_fn=lambda x: summary_file.write(x + '\n') )

            # Per-run copy: appending to the shared list would add one more
            # TensorBoard callback on every pass through the loop.
            run_callbacks = list( base_callbacks )

            if K.backend() == 'tensorflow':

                run_callbacks.append( TensorBoard( log_dir = fldr_log,
                                                   histogram_freq = 0,
                                                   write_graph = True,
                                                   write_images = True ) )

            # The writer is only used to dump the graph definition, which
            # happens on construction; close it right away to flush the file.
            writer = tf.summary.FileWriter( fldr_log, train_graph )
            writer.close()

            ts = _timestamp()

            arch_fl_pth = os.path.join( fldr_arch, 'arch_' + hparam_str[0] + '_run_{:04d}_'.format(run[0]) + ts + '.yaml' )

            with open(arch_fl_pth, 'w') as arch_file:
                arch_file.write( train_model.to_yaml() )

            _save_snapshot( train_model, 'init', ts )

            train_model.fit_generator(
                generator = train_generator,
                steps_per_epoch = nbatches_train,
                epochs = num_epochs,
                max_queue_size = max_q_size,
                validation_data = val_generator,
                validation_steps = nbatches_val,
                workers = 0,
                callbacks = run_callbacks,
                initial_epoch = initial_epoch)

            _save_snapshot( train_model, 'final', _timestamp() )

        # Release per-run resources before the next learning rate.
        if K.backend() == 'tensorflow':
            K.clear_session()

        train_sess.close()

        del train_model
        gc.collect()

# Entry point to the program: run the training driver only when this file is
# executed as a script, not when it is imported.
if __name__ == '__main__':
    main()