tensorflow / fold

Deep learning with dynamic computation graphs in TensorFlow
Apache License 2.0

Need help to fold the NaryTreeLSTMCell #74

Open bdqnghi opened 7 years ago

bdqnghi commented 7 years ago

I'm trying to implement a true N-ary version of the TreeLSTM, rather than the binary version from this example: https://github.com/tensorflow/fold/blob/master/tensorflow_fold/g3doc/sentiment.ipynb
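For context, Tai et al. (2015), who introduced the TreeLSTM, describe two ways to handle more than two children. In the N-ary TreeLSTM each child position has its own weight matrices, so the number of child slots N must be fixed when the model is built (the binary cell in the sentiment example is essentially the N = 2 case). In the Child-Sum TreeLSTM the children's hidden states are summed, so the set of children C(j) of node j can have any size. The input gate and memory cell for each variant, as I recall them from the paper (the other gates follow the same pattern):

```latex
% N-ary TreeLSTM: N fixed child slots, one U_\ell per slot
i_j = \sigma\Big(W^{(i)} x_j + \sum_{\ell=1}^{N} U^{(i)}_{\ell}\, h_{j\ell} + b^{(i)}\Big)
c_j = i_j \odot u_j + \sum_{\ell=1}^{N} f_{j\ell} \odot c_{j\ell}

% Child-Sum TreeLSTM: any number of children k \in C(j)
\tilde{h}_j = \sum_{k \in C(j)} h_k
i_j = \sigma\Big(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)}\Big)
c_j = i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k
```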

Here is what I have so far:

import tensorflow_fold as td

# lookup_word, word_embedding, embed_subtree, tree_lstm and output_layer are
# defined elsewhere, adapted from the sentiment example notebook.

def logits_and_state():
  """Creates a block that goes from tokens to (logits, state) tuples."""
  # Item 0 of a node is its word: look up its index and embed it.
  word2vec = (td.GetItem(0) >> td.InputTransform(lookup_word) >>
              td.Scalar('int32') >> word_embedding)

  # Placeholder: this is the number I don't know how to obtain.
  children_num = ...
  # One recursive subtree embedding per child slot.
  children2vec_list = list()
  children2vec_list.append(embed_subtree())
  for i in range(children_num):
    children2vec_list.append(embed_subtree())

  children2vec = tuple(children2vec_list)

  # Carried over from the binary example: the tree layer expects two child
  # states, so a leaf gets two all-zero states.
  zero_state = td.Zeros((tree_lstm.state_size,) * 2)
  # Input is a word vector; internal nodes get an all-zero input.
  zero_inp = td.Zeros(word_embedding.output_type.shape[0])

  word_case = td.AllOf(word2vec, zero_state)
  children_case = td.AllOf(zero_inp, children2vec)

  tree2vec = td.OneOf(lambda x: 1 if len(x) == 1 else 2,
                      [(1, word_case), (2, children_case)])
  return tree2vec >> tree_lstm >> (output_layer, td.Identity())
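One way around needing children_num at all, not taken from this thread or from the Fold examples, is the Child-Sum TreeLSTM of Tai et al. (2015): it sums the children's hidden states and gives each child its own forget gate, so a node can have any number of children. The sketch below is plain Python rather than Fold code, just to show the data flow. The tree layout of (input_vector, children) tuples, the toy random weights, and the names make_params, child_sum_cell and encode are all hypothetical, and biases are omitted. In Fold I would expect the per-child step to correspond to something like td.Map over td.GetItem(1) followed by a td.Reduce, rather than a Python loop over a fixed count, but I have not checked that end to end.

```python
import math
import random

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def matvec(m, v):
    # m is a list of rows; returns the matrix-vector product m @ v.
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def add(*vs):
    # Elementwise sum of equal-length vectors.
    return [sum(xs) for xs in zip(*vs)]

def mul(a, b):
    # Elementwise (Hadamard) product.
    return [x * y for x, y in zip(a, b)]

def make_params(in_dim, hid, seed=0):
    """Hypothetical toy weights: one (W, U) pair per gate i, f, o, u; no biases."""
    rng = random.Random(seed)
    def mat(rows, cols):
        return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]
    return {g: (mat(hid, in_dim), mat(hid, hid)) for g in "ifou"}

def child_sum_cell(params, x, child_states, hid):
    """One Child-Sum TreeLSTM step; child_states is a list of (c, h) of any length."""
    h_sum = add([0.0] * hid, *[h for _, h in child_states])

    def gate(name, h, act):
        W, U = params[name]
        return [act(v) for v in add(matvec(W, x), matvec(U, h))]

    i = gate("i", h_sum, sigmoid)
    o = gate("o", h_sum, sigmoid)
    u = gate("u", h_sum, math.tanh)
    c = mul(i, u)
    # One forget gate per child, computed from that child's own hidden state.
    for c_k, h_k in child_states:
        c = add(c, mul(gate("f", h_k, sigmoid), c_k))
    h = mul(o, [math.tanh(v) for v in c])
    return c, h

def encode(params, node, hid):
    """Encode an (x, children) tree bottom-up; children may be any length."""
    x, children = node
    child_states = [encode(params, ch, hid) for ch in children]  # map over children
    return child_sum_cell(params, x, child_states, hid)

def leaf(v):
    return (v, [])

# Usage: the root has three children, one of which has a single child.
params = make_params(in_dim=3, hid=4)
tree = ([0.0, 0.0, 0.0],
        [leaf([1.0, 0.0, 0.0]),
         leaf([0.0, 1.0, 0.0]),
         ([0.0, 0.0, 0.0], [leaf([0.0, 0.0, 1.0])])])
c, h = encode(params, tree, hid=4)
print(len(h))  # 4
```

The only place the number of children appears is the length of child_states, which is discovered per node while walking the tree, never fixed in advance.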

children_num is what I'm stuck on: I have no idea how to get that number. I know the children can be obtained with td.GetItem(1), which produces a block whose output is the list of children, but how do I get the actual number of children out of that block?
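If you want to keep position-specific weights (a true N-ary cell), another option, again my own assumption rather than something from the Fold docs, is to treat children_num as a fixed maximum arity N chosen up front and normalize every tree in preprocessing so each internal node has exactly N children: extra children are dropped (which loses information) and empty slots are filled with a padding marker. Every internal node then has the same shape, so building N embed_subtree() blocks in a Python loop is well defined, though the padding slots would still need their own case (for example through td.OneOf) that yields zero states. A plain-Python sketch of just the preprocessing step, using the (word, children) layout implied by td.GetItem(0) and td.GetItem(1), with PAD and pad_tree as hypothetical names:

```python
PAD = None  # hypothetical marker for an empty child slot

def pad_tree(node, n):
    """Return a copy of a (word, children) tree where every internal node has exactly n children.

    Children beyond the first n are dropped; missing slots are filled with PAD.
    Leaves keep an empty child list so they still look like leaves.
    """
    word, children = node
    if not children:
        return (word, [])
    kept = [pad_tree(child, n) for child in children[:n]]
    return (word, kept + [PAD] * (n - len(kept)))

tree = ("root", [("a", []), ("b", []), ("c", [("d", [])])])
print(pad_tree(tree, 4))
# ('root', [('a', []), ('b', []), ('c', [('d', []), None, None, None]), None])
```

The trade-off is that N becomes a hyperparameter: a small N truncates wide nodes, while a large N wastes computation on padding.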

KazuhiraDZ commented 7 years ago

I have the same problem. Have you finished it yet?

vivian1993 commented 6 years ago

I also have the same problem. Has either of you finished it yet?