Someone reported this previously; it is a bug in numpy, so that contributor added save_npz_dict. You can have a look.
save_npz_dict() solved it. I found a reference post on StackOverflow that might give some pointers.
If we use load_and_assign_npz_dict to restore the parameters saved by save_npz_dict, the current source code only restores parameters that can be found in the tf.GraphKeys.TRAINABLE_VARIABLES collection. So it cannot restore parameters such as moving_mean and moving_variance in the BatchNormalization layer; the small sketch below shows the difference between the two collections.
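For example, here is a minimal TF 1.x sketch of the difference, using a hand-made non-trainable variable to stand in for the BN moving statistics (the variable name and shape are made up):

import tensorflow as tf

# A non-trainable variable, like BatchNorm's moving_mean / moving_variance.
with tf.variable_scope("bn"):
    moving_mean = tf.get_variable("moving_mean", shape=[64],
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)

# Not found via TRAINABLE_VARIABLES, so the current loader skips it ...
print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="bn"))  # []
# ... but it is found via GLOBAL_VARIABLES.
print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bn"))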
How about replacing it with "varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=key)"?
import numpy as np
import tensorflow as tf

def load_and_assign_npz_dict(name='model.npz', sess=None):
    """Restore the parameters saved by ``tl.files.save_npz_dict()``.

    Parameters
    ----------
    name : a string
        The name of the .npz file.
    sess : Session
        The TensorFlow session to run the assign ops in.
    """
    assert sess is not None
    params = np.load(name)
    if len(params.keys()) != len(set(params.keys())):
        raise Exception("Duplication in model npz_dict %s" % name)
    ops = list()
    for key in params.keys():
        try:
            # tensor = tf.get_default_graph().get_tensor_by_name(key)
            # varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=key)
            # How about checking the variable in GLOBAL_VARIABLES instead?
            varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=key)
            if len(varlist) > 1:
                raise Exception("[!] Multiple candidate variables to be assigned for name %s" % key)
            elif len(varlist) == 0:
                raise KeyError
            else:
                ops.append(varlist[0].assign(params[key]))
                print("[*] params restored: %s" % key)
        except KeyError:
            print("[!] Warning: Tensor named %s not found in network." % key)
    sess.run(ops)
    print("[*] Model restored from npz_dict %s" % name)
In an inference model, with the trainable flag set to False, the (beta, gamma) parameters in BN cannot be restored either.
Oh, thank you very much. It would be great if you could open a pull request, or would you like me to modify it?
When a network's number of hidden units is the same as its input's last dimension, tl.files.save_npz() raises a ValueError.
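This looks like the numpy issue mentioned at the top of the thread: if I read the old tl.files.save_npz correctly, it hands the whole parameter list to np.savez as a single entry, and numpy's list-to-array conversion breaks when the arrays share the same leading dimension but differ in the remaining ones, which is exactly what happens when the weight matrix is square. A minimal sketch of the underlying numpy behaviour (shapes are illustrative):

import numpy as np

# Hidden units == input dimension, so the weight matrix is square.
W = np.zeros((100, 100))  # weight, shape (input_dim, n_units)
b = np.zeros((100,))      # bias, shape (n_units,)

# numpy tries to pack the list into one (2, 100) float array and then
# cannot fit the (100, 100) matrix into a (100,) slot:
try:
    np.array([W, b])
except ValueError as e:
    print(e)  # could not broadcast input array from shape (100,100) into shape (100)

# With a non-square weight, e.g. W of shape (784, 100) and b of shape (100,),
# the leading lengths differ (784 vs 100), so numpy falls back to an object
# array instead and no ValueError is raised.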