tensorflow / nmt

TensorFlow Neural Machine Translation Tutorial
Apache License 2.0

Number of trainable parameters? #365

Open lvcasgm opened 6 years ago

lvcasgm commented 6 years ago

Hello everyone!

I'd like to know how to compute the number of trainable parameters in my network with the configuration I am using.

I've read here that the way to compute this number is by calling these functions:

import tensorflow as tf  # TF 1.x API, as used by tensorflow/nmt

def count_number_trainable_params():
    '''
    Counts the number of trainable parameters (not variables) in the default graph.
    '''
    tot_nb_params = 0
    for trainable_variable in tf.trainable_variables():
        shape = trainable_variable.get_shape()  # e.g. [D,F] or [W,H,C]
        current_nb_params = get_nb_params_shape(shape)
        tot_nb_params = tot_nb_params + current_nb_params
    return tot_nb_params

def get_nb_params_shape(shape):
    '''
    Computes the total number of params for a given shape.
    Works for shapes of any rank, e.g. for [D,F] or [W,H,C] it computes D*F and W*H*C.
    '''
    nb_params = 1
    for dim in shape:
        nb_params = nb_params * int(dim)  # raises an error if a dimension is unknown (None)
    return nb_params
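To sanity-check the arithmetic without building a TensorFlow graph, here is a minimal plain-Python sketch. It assumes you already have each variable's shape as a list of ints; the `count_params` helper, the `shapes` dict, its variable names and the sizes are all hypothetical, loosely modelled on a one-layer model with an embedding, an LSTM whose kernel is [input + hidden, 4 * hidden] plus a [4 * hidden] bias, and an output projection without bias.

```python
import math

def count_params(shapes):
    """Sum the element counts of all shapes (hypothetical helper)."""
    total = 0
    for name, shape in shapes.items():
        if any(d is None for d in shape):
            raise ValueError(f"{name} has an unknown dimension: {shape}")
        total += math.prod(shape)  # math.prod([]) == 1, so a scalar counts as 1
    return total

# Hypothetical sizes: vocabulary V=1000, embedding size d=128, hidden size h=128.
V, d, h = 1000, 128, 128
shapes = {
    "embedding": [V, d],                 # 1000 * 128 = 128000
    "lstm/kernel": [d + h, 4 * h],       # 256 * 512  = 131072
    "lstm/bias": [4 * h],                # 512
    "output_projection/kernel": [h, V],  # 128 * 1000 = 128000
}
print(count_params(shapes))  # → 387584
```

Each entry is simply the product of its dimensions, which is what get_nb_params_shape() computes for a variable's TensorShape.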

However, I don't know where to call these functions. From which function or file should I call count_number_trainable_params()? And how can I write the result to a text file?

Thanks a lot for your time!

luozhouyang commented 6 years ago

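You can call these functions once your network has been built, for example inside the session where you run it. tf.trainable_variables() reads the trainable variables of the current default graph, so the call has to happen in the graph that contains your model.

# Build the model in the default graph first; counting only reads the graph.
with tf.Session() as sess:
    print("Number of trainable parameters: %d" % count_number_trainable_params())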
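For the text-file part of the question, a minimal sketch in plain Python: compute the count after the model graph is built, then write it out. The helper name `save_param_count` and its default file name are hypothetical, not part of tensorflow/nmt.

```python
def save_param_count(count, path="num_trainable_params.txt"):
    """Write the parameter count to a plain-text file (hypothetical helper and default path)."""
    with open(path, "w") as f:
        f.write("Number of trainable parameters: %d\n" % count)

# Usage, assuming count_number_trainable_params() from the question is defined
# and the model has already been built in the default graph:
# save_param_count(count_number_trainable_params())
```

Opening the file with mode "w" overwrites any previous count; use "a" instead if you want to keep one line per run.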