tensorflow / fold

Deep learning with dynamic computation graphs in TensorFlow
Apache License 2.0

TypeError for td.OneOf #70

Open nitu0317 opened 7 years ago

nitu0317 commented 7 years ago

I am using td.OneOf to set up the terminal condition for a recursion, so I need two different blocks depending on whether the input list is empty or not.

td.OneOf(lambda pair: pair[0] == [], (add_metrics(is_train, is_empty=False), add_metrics(is_train, is_empty=True)))

But I get "TypeError: bad output type PyObjectType for <td.Composition.input 'metrics'>, expected TupleType". It seems that a PyObjectType is fed into add_metrics, but I am pretty sure that the output before td.OneOf is TupleType.

Thanks!

chao-su commented 6 years ago

Hi, I think I solved the problem.

The basic usage of td.OneOf is td.OneOf(key_fn, case_blocks), where key_fn is either a Python function or a block. If it is a Python function, Fold wraps it in a td.InputTransform block, whose input and output types are both PyObjectType.

Remember that key_fn, case_blocks, and td.OneOf itself must all have the same input type, because the output of whatever precedes td.OneOf is fed into key_fn as well as into the case blocks. When td.OneOf is constructed, Fold propagates the input type from the key_fn block to td.OneOf and then to case_blocks (see line 1628 of blocks.py). So the contradiction for case_blocks is between your input type, TupleType, and the expected PyObjectType.
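To make the propagation step concrete, here is a small toy model in plain Python. It is only a sketch of the idea and not Fold's actual code: the Block class, the one_of function, and the use of strings as type names are all hypothetical stand-ins, and the real check in blocks.py may differ in detail.

```python
# Toy model of the type check described above. This is NOT Fold's code:
# Block, one_of and the string type names are hypothetical stand-ins.

class Block:
    def __init__(self, name, input_type=None):
        self.name = name
        self.input_type = input_type  # None means "not yet known"

    def set_input_type(self, new_type):
        # Accept a type if none is set yet; otherwise it must match.
        if self.input_type is None:
            self.input_type = new_type
        elif self.input_type != new_type:
            raise TypeError("bad input type %s for %s, expected %s"
                            % (new_type, self.name, self.input_type))


def one_of(key_fn, case_blocks):
    """Push key_fn's input type onto every case block, then return it."""
    for case in case_blocks:
        case.set_input_type(key_fn.input_type)
    return key_fn.input_type


# A plain Python function wrapped as a key block takes PyObjectType input.
python_key = Block("key_fn", input_type="PyObjectType")
# A case block that indexes into a tuple needs TupleType input.
tuple_case = Block("case", input_type="TupleType")

try:
    one_of(python_key, [tuple_case])
    error = None
except TypeError as exc:
    error = str(exc)
print(error)

# A key block that itself starts from the tuple agrees with the cases.
block_key = Block("key_fn", input_type="TupleType")
agreed = one_of(block_key, [Block("case", input_type="TupleType")])
print(agreed)
```

The first call fails because the key block pins the input type to PyObjectType before the case block gets a say. The second call succeeds because both sides already expect a tuple.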

See the following example.

    import tensorflow as tf
    import tensorflow_fold as td

    a = ['1', ['3', '4']]
    b = ['2', ['3', '4']]

    block1 = (td.Scalar('int32'), td.Scalar('int32'))
    block2 = td.Function(tf.add)
    block3 = td.Function(tf.multiply)

    # key_fn is a plain Python function here, which triggers the error below.
    oneof = (td.Identity(), block1) >> td.OneOf(
        key_fn=lambda x: x == '1',
        case_blocks=(td.GetItem(1) >> block2, td.GetItem(1) >> block3))

This raises "TypeError: bad input type PyObjectType for , expected TupleType or TensorType". It can be fixed by passing a block as key_fn instead of a plain Python function:

    oneof = (td.Identity(), block1) >> td.OneOf(
        key_fn=(td.GetItem(0) >> td.InputTransform(lambda x: x == '1')),
        case_blocks=(td.GetItem(1) >> block2, td.GetItem(1) >> block3))

    oneof.eval(a)  # => array(12, dtype=int32)
    oneof.eval(b)  # => array(7, dtype=int32)
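If it helps to see why a gives 12 and b gives 7, here is the same dispatch written as plain Python with no Fold involved. The dispatch function is a hypothetical helper for illustration only. It assumes that a boolean key selects a case by position (False picks the first block, True the second), which is consistent with the results reported above.

```python
# Plain-Python walk-through of the fixed example; no Fold required.
# dispatch() is a hypothetical helper, not part of the td API.

def dispatch(item):
    key_part, numbers = item            # item looks like ['1', ['3', '4']]
    x, y = (int(n) for n in numbers)    # stands in for the two td.Scalar('int32') blocks
    cases = (lambda: x + y,             # index 0, like block2 (tf.add)
             lambda: x * y)             # index 1, like block3 (tf.multiply)
    key = (key_part == '1')             # what the key block computes from item[0]
    return cases[key]()                 # False selects index 0, True selects index 1


result_a = dispatch(['1', ['3', '4']])  # key is True, so the product is taken
result_b = dispatch(['2', ['3', '4']])  # key is False, so the sum is taken
print(result_a, result_b)
```

Note that the key only looks at the first element, while both cases only look at the second. That is the same split the td.GetItem(0) and td.GetItem(1) blocks make in the Fold version.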