tensorflow / fold

Deep learning with dynamic computation graphs in TensorFlow
Apache License 2.0

How to use images as input #14

Closed fabioasdias closed 7 years ago

fabioasdias commented 7 years ago

Imagine taking that cute animation on the first page, replacing each input word with a different image, and replacing 'embed' with a tf.nn.conv2d. That's what I want to try. Can it be done?

I tried td.Function(tf.nn.conv2d), but it expects more arguments...

delesley commented 7 years ago

Yes, you need to supply the extra arguments. :-)

td.Function(lambda x: tf.nn.conv2d(x, filter, strides, padding)) for some choice of filter/strides/padding. See https://www.tensorflow.org/api_docs/python/tf/nn/conv2d.

You'll also have to use some InputTransform to decode the image data into the format that conv2d is expecting.
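An InputTransform wraps an ordinary Python function that runs outside the TensorFlow graph, so the decoding step itself needs no Fold machinery. A minimal sketch of such a decoder (the function name and the 24x10x1 shape are placeholders for this example, not anything from the thread):

```python
import numpy as np

def decode_image(raw):
    """Turn one input example (here: a nested list of pixel values)
    into the float32 (height, width, channels) array conv2d expects."""
    arr = np.asarray(raw, dtype=np.float32)
    # conv2d wants a rank-3 (H, W, C) tensor per example; add the
    # channel axis for grayscale data.
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    return arr

# In a Fold pipeline this would typically be composed as
#   td.InputTransform(decode_image) >> td.Tensor((24, 10, 1)) >> ...
example = decode_image([[0.0] * 10] * 24)
print(example.shape)   # (24, 10, 1)
```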


fabioasdias commented 7 years ago

Indeed that makes it better, thanks. However:

def _kernel(name, shape):
    var = tf.get_variable(name, shape,
                          initializer=tf.truncated_normal_initializer(stddev=5e-2, dtype=tf.float32),
                          dtype=tf.float32)
    return var

(...)

    inT = (td.Map(td.Tensor((24,10,1)) >>
                  td.Function(lambda x: tf.nn.conv2d(x,
                                                     filter=_kernel('weights',[3,3,1,32]),
                                                     strides=[1,1,1,1],
                                                     padding="SAME")) >>
                  td.Function(lambda x: tf.reshape(x,[-1,24*32*10]))))

gets me:

ValueError: Variable weights already exists, disallowed. Did you mean to set reuse=True in VarScope?

And my attempts at wrapping things in with tf.variable_scope('conv1') as scope: weren't any better; I'm guessing td uses scopes internally...

EDIT: I also tried to use td.ScopedLayer (similar to the example with rnn) but conv2d doesn't accept scope. I'm guessing there is some trick to glue all this together, but I'm not finding it.

fabioasdias commented 7 years ago

I managed to do it by wrapping the convolution stuff into a ScopedLayer:

def convLayer(inX, scope):
    with tf.variable_scope(scope) as sc:
        conv = tf.nn.conv2d(tf.reshape(inX, [-1, 24, 10, 1]),
                            filter=_newVar('weights', [3, 3, 1, 64]),
                            strides=[1, 1, 1, 1],
                            padding='VALID')
        biases = _newVar('biases', [64])
        pre_activation = tf.nn.bias_add(conv, biases)
        conv = tf.nn.relu(pre_activation, name=sc.name)
    return conv

def train(batch_size=100):
    encod = (td.Map(td.Tensor((24,wav.shape[2])) >>
                    td.ScopedLayer(convLayer) >>
                    td.Function(lambda x: tf.reshape(x,[-1,64*(10-2)]))) >>
             td.Fold(td.Concat() >>
                     td.Function(td.FC(128)), td.FromTensor(tf.zeros(128))))
    model = ( encod  >>  td.Function(td.FC(1)))

(...)

At least the code runs; I'm not entirely sure it will do what I want :)

delesley commented 7 years ago

TensorFlow Fold doesn't use scopes internally, but tf.get_variable does, and you're wrapping a call to get_variable into a lambda and then passing the lambda to the Fold library. The lambda is not invoked when you call td.Function; it's invoked later, when we build the loom, in whatever variable scope happens to be active. That's a recipe for name and scope collisions.

As you discovered, the td.Layers library provides some utility classes for dealing with scopes. ScopedLayer passes an additional "scope" argument that you can use to create variables, as you do in your code. Alternatively, TensorToTensorLayer lets you override two methods, _create_variables and _process_batch, so you can separate the two steps.
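The idea behind that split can be illustrated without TensorFlow: create the parameters exactly once, then apply them to as many batches as needed. A minimal plain-Python analogue (class and method names here mirror the pattern, not Fold's actual API):

```python
class TwoPhaseLayer:
    """Analogue of a layer that separates variable creation from use:
    _create_variables runs once, _process_batch runs on every call."""

    def __init__(self):
        self._created = False

    def _create_variables(self):
        self.scale = 2.0          # stand-in for a weight tensor

    def _process_batch(self, batch):
        return [x * self.scale for x in batch]

    def __call__(self, batch):
        if not self._created:     # lazily create variables on first use
            self._create_variables()
            self._created = True
        return self._process_batch(batch)

layer = TwoPhaseLayer()
print(layer([1, 2]))   # [2.0, 4.0]
print(layer([3]))      # [6.0] -- same variables, no re-creation
```

Because creation happens once per layer object rather than once per invocation, reusing the layer across many inputs never trips a duplicate-name check.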
