daniel-kukiela / nmt-chatbot


'AttentionWrapper' object has no attribute 'zero_state' #146

Open · Neel125 opened this issue 4 years ago

Neel125 commented 4 years ago

def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
                        source_sequence_length):
    """Build a RNN cell with attention mechanism that can be used by decoder."""
    # No Attention
    if not self.has_attention:
        return super(AttentionModel, self)._build_decoder_cell(
            hparams, encoder_outputs, encoder_state, source_sequence_length)
    elif hparams["attention_architecture"] != "standard":
        raise ValueError(
            "Unknown attention architecture %s" % hparams["attention_architecture"])

    num_units = hparams["num_units"]
    num_layers = self.num_decoder_layers
    num_residual_layers = self.num_decoder_residual_layers
    infer_mode = hparams["infer_mode"]

    dtype = tf.float32

    # Ensure memory is batch-major
    if self.time_major:
        memory = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
        memory = encoder_outputs

    if (self.mode == tf.estimator.ModeKeys.PREDICT and
            infer_mode == "beam_search"):
        memory, source_sequence_length, encoder_state, batch_size = (
            self._prepare_beam_search_decoder_inputs(
                hparams["beam_width"], memory, source_sequence_length,
                encoder_state))
    else:
        batch_size = self.batch_size

    # Attention
    attention_mechanism = self.attention_mechanism_fn(
        hparams["attention"], num_units, memory, source_sequence_length,
        self.mode)

    cell = model_helper.create_rnn_cell(
        unit_type=hparams["unit_type"],
        num_units=num_units,
        num_layers=num_layers,
        num_residual_layers=num_residual_layers,
        forget_bias=hparams["forget_bias"],
        dropout=hparams["dropout"],
        num_gpus=self.num_gpus,
        mode=self.mode,
        single_cell_fn=self.single_cell_fn)

    # Only generate alignment in greedy INFER mode.
    alignment_history = (self.mode == tf.estimator.ModeKeys.PREDICT and
                         infer_mode != "beam_search")
    cell = tfa.seq2seq.AttentionWrapper(
        cell,
        attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=alignment_history,
        output_attention=hparams["output_attention"],
        name="attention")

    # TODO(thangluong): do we need num_layers, num_gpus?
    device = tf.device(model_helper.get_device_str(num_layers - 1, self.num_gpus))

    cell = tf.nn.rnn_cell.DeviceWrapper(cell, device)
    cell = tf.nn.rnn_cell.DropoutWrapper(cell, input_keep_prob=0.8)
    if hparams["pass_hidden_state"]:
        decoder_initial_state = cell.zero_state(
            batch_size=batch_size * hparams["beam_width"],
            dtype=dtype).clone(cell_state=encoder_state)
    else:
        decoder_initial_state = cell.zero_state(
            batch_size=batch_size * hparams["beam_width"], dtype=dtype)

    return cell, decoder_initial_state

Error:

  File "/home/ml-ai4/Neel-dev023/ChatBot/nmt-chatbot/nmt/nmt/attention_model.py", line 144, in _build_decoder_cell
    decoder_initial_state = cell.zero_state(batch_size=batch_size*hparams["beam_width"], dtype=dtype).clone(
  File "/home/ml-ai4/Neel-dev023/ChatBot/nmt-chatbot/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_wrapper_impl.py", line 199, in zero_state
    return self.cell.zero_state(batch_size, dtype)
  File "/home/ml-ai4/Neel-dev023/ChatBot/nmt-chatbot/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_wrapper_impl.py", line 431, in zero_state
    return self.cell.zero_state(batch_size, dtype)
AttributeError: 'AttentionWrapper' object has no attribute 'zero_state'
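
For what it's worth, the same API mismatch can be reproduced outside the project: tfa.seq2seq.AttentionWrapper is a Keras-style cell that exposes get_initial_state() rather than the TF1-era zero_state(), so the zero_state() call that the tf.nn.rnn_cell wrappers delegate down to it fails. Below is a minimal, hypothetical sketch of that difference, assuming TF 2.x with tensorflow_addons installed; the LuongAttention and LSTMCell used here are only stand-ins for whatever attention_mechanism_fn and model_helper.create_rnn_cell actually return, not the project's code.

# Hypothetical reproducer of the API mismatch, not the project's code.
import tensorflow as tf
import tensorflow_addons as tfa

num_units, batch_size = 128, 4
memory = tf.random.normal([batch_size, 10, num_units])  # dummy encoder outputs

attention_mechanism = tfa.seq2seq.LuongAttention(num_units, memory)
cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(num_units),
    attention_mechanism,
    attention_layer_size=num_units)

# Keras-style API: this works.
initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)

# TF1-style API: this raises the same AttributeError as in the traceback above.
# cell.zero_state(batch_size, tf.float32)

If that is indeed the cause, the decoder_initial_state code would presumably need to be rewritten around get_initial_state() (with .clone(cell_state=...) on the resulting state), but I have not verified that this is the only problem.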