keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

CuDNN RNN layers nested in TimeDistributed are not converted when loading #10356

Closed: bzamecnik closed this issue 6 years ago

bzamecnik commented 6 years ago

Creating an issue from comment https://github.com/keras-team/keras/issues/10080#issuecomment-394640409 on #10080.

When weights are converted between CuDNNGRU and GRU and the layer is wrapped in TimeDistributed, the conversion is mistakenly skipped. This is similar to the earlier bugs with Bidirectional (#8908) and Model/Sequential (#10080).
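A minimal sketch of the idea, assuming the loader picks a conversion based on the layer's type: if only a top-level GRU/CuDNNGRU is recognized, a TimeDistributed wrapper hides the inner layer and its weights pass through unconverted. Recursing into the wrapper, as Bidirectional already requires, avoids that. The classes and the functions `convert_gru_bias` and `preprocess_weights` below are hypothetical stand-ins, not the Keras API. Only the bias is reshaped; a full conversion also has to rearrange the kernel matrices.

```python
import numpy as np

# Hypothetical stand-ins for Keras layers; NOT the real Keras classes.
# GRU here stands for GRU(reset_after=True), the CuDNN-compatible variant.
class GRU:
    def __init__(self, units):
        self.units = units

class CuDNNGRU:
    def __init__(self, units):
        self.units = units

class TimeDistributed:
    def __init__(self, layer):
        self.layer = layer  # the wrapped inner layer

class Bidirectional:
    def __init__(self, layer):
        # A real Bidirectional holds two separate copies; one object suffices here.
        self.forward_layer = layer
        self.backward_layer = layer

def convert_gru_bias(target, weights):
    """Reshape only the bias between layouts; kernel reordering is omitted."""
    kernel, recurrent_kernel, bias = weights
    if isinstance(target, CuDNNGRU):
        bias = bias.reshape(-1)       # (2, 3*units) -> (6*units,)
    else:
        bias = bias.reshape((2, -1))  # (6*units,) -> (2, 3*units)
    return [kernel, recurrent_kernel, bias]

def preprocess_weights(layer, weights):
    """Unwrap container layers so the inner RNN layer's weights get converted."""
    if isinstance(layer, TimeDistributed):
        # The wrapper owns no weights of its own: delegate to the inner layer.
        return preprocess_weights(layer.layer, weights)
    if isinstance(layer, Bidirectional):
        # Assumed order: forward weights first, then backward weights.
        half = len(weights) // 2
        return (preprocess_weights(layer.forward_layer, weights[:half])
                + preprocess_weights(layer.backward_layer, weights[half:]))
    if isinstance(layer, (GRU, CuDNNGRU)):
        return convert_gru_bias(layer, weights)
    return weights  # any other layer: leave untouched

units = 2
saved_from_cudnn = [np.zeros((10, 3 * units)), np.zeros((units, 3 * units)),
                    np.zeros(6 * units)]
converted = preprocess_weights(TimeDistributed(GRU(units)), saved_from_cudnn)
print(converted[2].shape)  # (2, 6)
```

Without the TimeDistributed branch, the wrapper would fall through to `return weights` and the flat bias would reach the GRU unchanged.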

Example of failure:

import tempfile
import os

from keras.layers import Input, GRU, CuDNNGRU, TimeDistributed
from keras.models import Model
import numpy as np

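units = 2  # consistent with the shapes in the error below: 3 * units == 6

input2 = Input((4, 6, 10))  # 4 outer steps, each a (6, 10) sequence for the inner GRU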
output_cudnn = TimeDistributed(CuDNNGRU(units))(input2)
output = TimeDistributed(GRU(units, activation='hard_sigmoid', reset_after=True))(input2)
model_cudnn = Model(input2, output_cudnn)
model_plain = Model(input2, output)

fd, fname = tempfile.mkstemp('.h5')
os.close(fd)  # only the path is needed; Keras reopens the file itself
model_cudnn.save_weights(fname)
model_plain.load_weights(fname)

Fails with:

local/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in _create_c_op(graph, node_def, inputs, control_inputs)
   1568   except errors.InvalidArgumentError as e:
   1569     # Convert to ValueError for backwards compatibility.
-> 1570     raise ValueError(str(e))
   1571 
   1572   return c_op

ValueError: Shapes must be equal rank, but are 2 and 1 for 'Assign_11' (op: 'Assign') with input shapes: [2,6], [12].
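The shapes in the error are consistent with units = 2: GRU(reset_after=True) keeps its input and recurrent biases as a (2, 3 * units) = (2, 6) array, while CuDNNGRU stores them as one flat vector of 6 * units = 12 values, so assigning the unconverted CuDNN bias fails with a rank mismatch. The numpy sketch below shows the reshape that bridges the two layouts, assuming the flat vector holds the input biases followed by the recurrent biases.

```python
import numpy as np

units = 2    # inferred from the error: 3 * units == 6
n_gates = 3  # update, reset and candidate

# Flat CuDNNGRU-style bias: input biases, then recurrent biases (assumed order).
cudnn_bias = np.arange(2 * n_gates * units, dtype=np.float32)

# GRU(reset_after=True)-style bias: row 0 = input biases, row 1 = recurrent biases.
gru_bias = cudnn_bias.reshape((2, n_gates * units))

print(cudnn_bias.shape, gru_bias.shape)  # (12,) (2, 6)

# The reverse direction (plain GRU -> CuDNNGRU) just flattens again.
assert np.array_equal(gru_bias.reshape(-1), cudnn_bias)
```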

Without TimeDistributed, it works:

units = 2

input2 = Input((6, 10))
output_cudnn = CuDNNGRU(units)(input2)
output = GRU(units, activation='hard_sigmoid', reset_after=True)(input2)
model_cudnn = Model(input2, output_cudnn)
model_plain = Model(input2, output)

fd, fname = tempfile.mkstemp('.h5')
os.close(fd)  # only the path is needed; Keras reopens the file itself
model_cudnn.save_weights(fname)
model_plain.load_weights(fname)

The same failure occurs in the other direction (plain GRU -> CuDNNGRU).

Do we know of any other wrapper layers that need this conversion?

bzamecnik commented 6 years ago

@Ajk4 Fixed. Waiting for CI and review. Works for me. Thanks for the report.

Ajk4 commented 6 years ago

Thanks!

bzamecnik commented 6 years ago

Merged.