tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Unpooling layer in tensorflow #632

Closed ziky90 closed 3 years ago

ziky90 commented 8 years ago

It would be nice to also have in TensorFlow the unpooling layer as described in the paper on deconvolution networks: http://cvlab.postech.ac.kr/research/deconvnet/

I googled around a bit and found that an unpooling layer would be helpful for others as well: http://stackoverflow.com/questions/36548736/tensorflow-unpooling
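
For context, here is a tiny sketch (my own illustration, not from the paper) of the building block the rest of this thread relies on: tf.nn.max_pool_with_argmax returns both the pooled values and the flat indices of where each maximum came from, which is exactly the "switch" information an unpooling layer needs (note that this op was GPU-only in older TF versions, as discussed further down).

import tensorflow as tf

# a single 2x2 single-channel image; the max (4.0) sits at flat position 3
x = tf.reshape(tf.constant([[1., 2.], [3., 4.]]), [1, 2, 2, 1])
pooled, argmax = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

with tf.Session() as sess:
    p, a = sess.run([pooled, argmax])
    # p -> [[[[4.]]]], a -> [[[[3]]]]; unpooling puts 4.0 back at flat index 3 and zeros elsewhere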

geelenb commented 7 years ago

Why could tf.gather not be used for this problem? Does it not propagate gradients?

sirajulsalekin commented 7 years ago

Hi @hermitman: Were you able to solve the problem of gradients not backpropagating? I used @EmmaBYPeng's code and faced the same issue as you: the weights don't update during backpropagation after the unpooling layer.

dbarnes commented 7 years ago

@hermitman @sirajulsalekin For whatever reason I also could not backprop through @fabianbormann's or @EmmaBYPeng's unpool code (although they were able to). If anyone comes across the same problem, my solution was to replace the sparse_tensor_to_dense op as mentioned in https://github.com/tensorflow/tensorflow/issues/6391, since sparse_tensor_to_dense doesn't have a gradient (0.11).

Naturally this doesn't change the speed problems, but as far as I can tell only a dedicated op, as mentioned by @girving and @syed-ahmed, will solve that.

Namely

# Original
return tf.sparse_tensor_to_dense(tf.sparse_reorder(delta))

# New
# https://github.com/tensorflow/tensorflow/issues/6391
return tf.sparse_add(tf.zeros(tf.to_int32(delta.shape)), tf.sparse_reorder(delta))
Pepslee commented 7 years ago

Much faster:

def unpool_layer2x2_batch(bottom, argmax):
    # bottom: pooled tensor [b, h/2, w/2, c]; argmax: flat indices from max_pool_with_argmax
    bottom_shape = tf.shape(bottom)
    top_shape = [bottom_shape[0], bottom_shape[1] * 2, bottom_shape[2] * 2, bottom_shape[3]]

    batch_size = top_shape[0]
    height = top_shape[1]
    width = top_shape[2]
    channels = top_shape[3]

    # unravel the flat argmax into spatial (y, x) indices
    argmax_shape = tf.to_int64([batch_size, height, width, channels])
    argmax = unravel_argmax(argmax, argmax_shape)

    # channel index for every pooled element
    t1 = tf.to_int64(tf.range(channels))
    t1 = tf.tile(t1, [batch_size * (width // 2) * (height // 2)])
    t1 = tf.reshape(t1, [-1, channels])
    t1 = tf.transpose(t1, perm=[1, 0])
    t1 = tf.reshape(t1, [channels, batch_size, height // 2, width // 2, 1])
    t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4])

    # batch index for every pooled element
    t2 = tf.to_int64(tf.range(batch_size))
    t2 = tf.tile(t2, [channels * (width // 2) * (height // 2)])
    t2 = tf.reshape(t2, [-1, batch_size])
    t2 = tf.transpose(t2, perm=[1, 0])
    t2 = tf.reshape(t2, [batch_size, channels, height // 2, width // 2, 1])

    # spatial (y, x) indices from the unraveled argmax
    t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0])

    # assemble the full [b, y, x, c] indices and scatter the pooled values
    t = tf.concat(4, [t2, t3, t1])
    indices = tf.reshape(t, [(height // 2) * (width // 2) * channels * batch_size, 4])

    x1 = tf.transpose(bottom, perm=[0, 3, 1, 2])
    values = tf.reshape(x1, [-1])
    return tf.scatter_nd(indices, values, tf.to_int64(top_shape))
dbarnes commented 7 years ago

Thanks @Pepslee, I was going to try tf.scatter_nd as well, but unfortunately I'm stuck on 0.11 at the moment, which doesn't have it.

Pepslee commented 7 years ago

I tried this function on the master branch of TensorFlow.

sirajulsalekin commented 7 years ago

Thanks a lot @danbarnes333. I tried your code and it worked! I will run @Pepslee's code too to see the difference.

Pepslee commented 7 years ago

Even faster:

def unravel_argmax(argmax, shape):
    # convert the flat argmax indices into stacked [b, y, x, c] index tuples
    argmax_shape = argmax.get_shape()
    new_1dim_shape = tf.shape(tf.constant(0, shape=[tf.Dimension(4), argmax_shape[0]*argmax_shape[1]*argmax_shape[2]*argmax_shape[3]]))
    batch_shape = tf.constant(0, dtype=tf.int64, shape=[argmax_shape[0], 1, 1, 1]).get_shape()
    b = tf.multiply(tf.ones_like(argmax), tf.reshape(tf.range(shape[0]), batch_shape))
    y = argmax // (shape[2] * shape[3])
    x = argmax % (shape[2] * shape[3]) // shape[3]
    c = tf.ones_like(argmax) * tf.range(shape[3])
    pack = tf.stack([b, y, x, c])
    pack = tf.reshape(pack, new_1dim_shape)
    pack = tf.transpose(pack)
    return pack

def unpool(updates, mask, ksize=[1, 2, 2, 1]):
    # scatter the pooled values back into an output ksize-times larger in each spatial dimension
    input_shape = updates.get_shape()
    new_dim_y = input_shape[1] * ksize[1]
    new_dim_x = input_shape[2] * ksize[2]
    output_shape = tf.to_int64((tf.constant(0, dtype=tf.int64, shape=[input_shape[0], new_dim_y, new_dim_x, input_shape[3]]).get_shape()))
    indices = unravel_argmax(mask, output_shape)
    new_1dim_shape = tf.shape(tf.constant(0, shape=[input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]]))
    values = tf.reshape(updates, new_1dim_shape)
    ret = tf.scatter_nd(indices, values, output_shape)
    return ret
amortazi commented 7 years ago

Hi @Pepslee, can you explain what 'mask' is in your unpool function? Also, can you give a simple example of how to use that function in my own program? I am trying to use it in a deconvolution network. Thanks, Ali

Pepslee commented 7 years ago

Hi @amortazi, 'mask' is the second output of the tf.nn.max_pool_with_argmax(input=image, ksize=ksize, strides=[1, 2, 2, 1], padding='SAME') operation. It is the tensor of indices of the max values of the input tensor, i.e. the input to the max-pool operation. This operation (unpool) is the inverse of max-pool.
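
A minimal usage sketch of that pairing (assuming the unpool function above, a fixed batch size, and illustrative shapes):

import tensorflow as tf

x = tf.placeholder(tf.float32, [8, 64, 64, 16])            # encoder feature map
pooled, mask = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# ... further encoder / decoder layers ...
unpooled = unpool(pooled, mask, ksize=[1, 2, 2, 1])        # back to [8, 64, 64, 16]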

Pepslee commented 7 years ago

Reworked the code into a single function:

def unpool(updates, mask, ksize=[1, 2, 2, 1]):
    input_shape = updates.get_shape().as_list()
    # calculate the new (unpooled) shape
    output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])
    # calculate indices for batch, height, width and feature maps
    one_like_mask = tf.ones_like(mask)
    batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int64), shape=[input_shape[0], 1, 1, 1])
    b = one_like_mask * batch_range
    y = mask // (output_shape[2] * output_shape[3])
    x = mask % (output_shape[2] * output_shape[3]) // output_shape[3]
    feature_range = tf.range(output_shape[3], dtype=tf.int64)
    f = one_like_mask * feature_range
    # transpose indices & reshape update values to one dimension
    updates_size = tf.size(updates)
    indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
    values = tf.reshape(updates, [updates_size])
    ret = tf.scatter_nd(indices, values, output_shape)
    return ret
girving commented 7 years ago

@zheng-xq Opinions on whether we should accept an unpooling layer that just calls scatter_nd? A native op would be faster but probably not tremendously faster, since fundamentally it is just doing a scatter.

amortazi commented 7 years ago

Hi, thanks for the response @Pepslee. I ran it and it works well; the only problem is that it is slow. I checked, and it seems some operations (like // and %) cannot run on the GPU, so in the middle of my network they are handed off to the CPU, which is why it is slow. I am wondering if anyone has a suggestion for solving this problem. I have also asked this question on Stack Overflow: (http://stackoverflow.com/questions/41797875/running-tf-mod-and-tf-floor-div-in-tensorflow-in-gpu) Thanks
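
One way to confirm which device each op lands on (a standard TF 1.x option, not specific to this thread) is to enable device placement logging; unpool_out below is a placeholder name for whatever tensor you are profiling:

import tensorflow as tf

# prints the device (CPU or GPU) chosen for every op when the graph runs
config = tf.ConfigProto(log_device_placement=True)
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(unpool_out)   # add a feed_dict if the graph contains placeholders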

Pepslee commented 7 years ago

@girving Which native op would be faster?

girving commented 7 years ago

@Pepslee A hypothetical native op that does the whole thing at once. Do you have some intuition for how that would compare speed-wise?

ivankreso commented 7 years ago

@amortazi tf.scatter_nd is CPU-only as well. I think the code above can be rewritten without // and %, but scatter is probably the main bottleneck.

amortazi commented 7 years ago

@ivankreso Yes, tf.stack and tf.scatter_nd are both bottlenecks. Currently, both of them run only on CPU. I am trying to get them registered for GPU so we can remove those bottlenecks. Here are the issues about stack and scatter:
https://github.com/tensorflow/tensorflow/issues/7026 https://github.com/tensorflow/tensorflow/issues/7027

I would highly appreciate it if anyone can help! Thanks

Pepslee commented 7 years ago

I can rework this code without tf.stack, //, and %, but I can't find a GPU analog of tf.scatter_nd.

mshunshin commented 7 years ago

@Pepslee Many thanks for your code. If you are using None as the batch_size in the placeholder, batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int64), shape=[input_shape[0], 1, 1, 1]) fails with an error in reshape. Is there any way around this?

guyeng0 commented 7 years ago

@Pepslee @mshunshin I'm also encountering the same error in reshape when the batch_size is None. Did you manage to solve it?

yselivonchyk commented 7 years ago

I am not sure if this is the right place.

According to the docs, max_pool_with_argmax computes indices using the formula: [b, y, x, c] -> ((b * height + y) * width + x) * channels + c

while in fact the batch index is ignored and the formula actually used is: [b, y, x, c] -> (y * width + x) * channels + c

So something is wrong there.

The code by @Pepslee seems to take this into account.
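
A small worked example (my own numbers) of decoding a flat index with the batch-free formula that the op actually uses; this is the same // and % arithmetic the unpool code above performs:

# image of height 4, width 6, 3 channels; element at (y=2, x=5, c=1)
height, width, channels = 4, 6, 3
flat = (2 * width + 5) * channels + 1        # -> 52

c = flat % channels                          # 1
x = (flat // channels) % width               # 5
y = flat // (channels * width)               # 2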

Enet4 commented 7 years ago

I took a shot at adapting the unpooling function to support a partially defined input shape (without the batch size); here it is:

def unpool(updates, mask, ksize=2, name="unpool"):
    if isinstance(ksize, int):
        ksize = [1, ksize, ksize, 1]
    input_shape = updates.get_shape().as_list()
    # calculate the new (unpooled) shape
    output_shape = [input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]
    # calculate indices for batch, height, width and feature maps
    one_like_mask = tf.ones_like(mask)
    bsize = tf.to_int64(tf.shape(updates)[0])
    batch_range = tf.reshape(tf.range(bsize, dtype=tf.int64),
                             shape=[-1, 1, 1, 1])
    b = one_like_mask * batch_range
    y = mask // (output_shape[1] * output_shape[2])
    x = mask % (output_shape[1] * output_shape[2]) // output_shape[2]
    feature_range = tf.range(output_shape[2], dtype=tf.int64)
    f = one_like_mask * feature_range
    # transpose indices & reshape update values to one dimension
    updates_size = tf.size(updates)
    indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
    values = tf.reshape(updates, [updates_size])
    ret = tf.scatter_nd(indices, values, tf.concat(
        [[bsize], tf.to_int64(output_shape)], axis=0))
    return ret

In the batch range, I simply took advantage of the -1 dimension. Then I fetched the batch size and built the final output shape dynamically (as a Tensor). Admittedly, I am not sure whether I used the fastest and most elegant operations, but it appears to be working on my end.

Panaetius commented 7 years ago

I think I found a bug/issue with tf.nn.max_pool_with_argmax and the unpool workaround presented here.

tf.nn.max_pool_with_argmax indices are calculated as (y * w + x) * channels + c, but "w" is the width of the input tensor, not the width of the input tensor plus padding, when padding is applied (padding='SAME' and an odd tensor width).

With the unpool method, the width is recovered by dividing/taking the modulo of that output with input_shape[2] * ksize[2]; with padding, this is 1 pixel larger than the width tf.nn.max_pool_with_argmax uses for its argmax output. So if padding is applied, every row of the output image of the unpool() op will be slightly offset, leaving the image slightly tilted.
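
A quick numeric illustration of that offset (my own numbers, one channel): with an input width of 5 and padding='SAME', the pooled output has width 3, so unpool assumes an output width of 3 * ksize[2] = 6, while max_pool_with_argmax encoded its indices with the original width of 5:

channels = 1
true_width, assumed_width = 5, 6                    # width used by the op vs. width assumed by unpool

flat = (1 * true_width + 4) * channels + 0          # the op encodes the element at (y=1, x=4) as 9
y = flat // (assumed_width * channels)              # 1 (still right here)
x = flat % (assumed_width * channels) // channels   # 3 instead of 4 -> the row is shifted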

I'm currently implementing SegNet, which has several unpool operations one after the other, each making the tilting worse if there was any padding for it, which is really noticeable when looking at the final output.

My workaround was to change the proposed unpool operation by simply adding an input argument for the output shape, as follows:

def unpool(updates, mask, ksize=[1, 2, 2, 1], output_shape=None, name=''):
    with tf.variable_scope(name):
        mask = tf.cast(mask, tf.int32)
        input_shape = tf.shape(updates, out_type=tf.int32)
        # calculate the new (unpooled) shape
        if output_shape is None:
            output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])

        # calculate indices for batch, height, width and feature maps
        one_like_mask = tf.ones_like(mask, dtype=tf.int32)
        batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], 0)
        batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int32), shape=batch_shape)
        b = one_like_mask * batch_range
        y = mask // (output_shape[2] * output_shape[3])
        x = (mask // output_shape[3]) % output_shape[2] #mask % (output_shape[2] * output_shape[3]) // output_shape[3]
        feature_range = tf.range(output_shape[3], dtype=tf.int32)
        f = one_like_mask * feature_range
        # transpose indices & reshape update values to one dimension
        updates_size = tf.size(updates)
        indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
        values = tf.reshape(updates, [updates_size])
        ret = tf.scatter_nd(indices, values, output_shape)
        return ret

Then, when calling the op, I supply the shape of the corresponding convolution in the encoder part of SegNet as output_shape, so the code uses the correct (well, incorrect...) width when transforming the tf.nn.max_pool_with_argmax indices.
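
For example (a sketch with illustrative layer names), the encoder/decoder pairing looks like this, with the pre-pooling tensor's dynamic shape passed straight through:

conv4 = tf.contrib.layers.conv2d(net, 64, 3)               # encoder feature map, possibly odd-sized
pool4, argmax4 = tf.nn.max_pool_with_argmax(
    conv4, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# ... decoder ...
up4 = unpool(pool4, argmax4, output_shape=tf.shape(conv4), name='unpool4')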

Arguably, this is a bug in tf.nn.max_pool_with_argmax, since it should calculate the argmax indices taking potential padding into account.

PavlosMelissinos commented 7 years ago

@Panaetius Have you submitted an issue or a pull request for tf.nn.max_pool_with_argmax?

Your solution is a little bit dirty but seems to work perfectly. Are there any plans to merge your proposal or should this be further discussed?

Panaetius commented 7 years ago

@PavlosMelissinos No, I haven't submitted an issue or pull request; I wanted to get feedback here first, since I wasn't 100% sure it's actually a bug.

Yes, the code is a little dirty; it was just a quick fix for a hobby project, but it's been working great for me so far. I think the max_pool_with_argmax issue should be fixed before adding any unpool op, and the code would have to be sanitized as well.

I'll probably write a small self-contained example to show the issue with max_pool_with_argmax and post a bug report later today.

PavlosMelissinos commented 7 years ago

Great! Agree with you on all points. Looking forward to it.

Pepslee commented 7 years ago

def unpool(pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):
    """
       Unpooling layer after max_pool_with_argmax.
       Args:
           pool:      max pooled output tensor
           ind:       argmax indices
           ksize:     ksize is the same as for the pool
       Return:
           unpool:    unpooling tensor
    """
    with tf.variable_scope(scope):
        input_shape = pool.get_shape().as_list()
        output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])
        pool_ = tf.reshape(pool, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]])
        batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3], 1])
        ind_ = tf.reshape(ind, [input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3], 1])
        ind_ = tf.concat(1, [b, ind_])
        ref = tf.Variable(tf.zeros([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]))
        ret = tf.scatter_nd_update(ref, ind_, pool_)
        ret = tf.reshape(ret, [output_shape[0], output_shape[1], output_shape[2], output_shape[3]])
        return ret
chahld commented 7 years ago

I've adapted Pepslee's version to use tf.scatter_nd instead of tf.scatter_nd_update, to avoid creating a Variable. The Variable caused problems with checkpoint files because of its fixed batch size: if you ran with a different batch size, the checkpoint could not be read.

import numpy as np  # np.prod is used below for the flat input size

def unpool(pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):
    """
       Unpooling layer after max_pool_with_argmax.
       Args:
           pool:   max pooled output tensor
           ind:      argmax indices
           ksize:     ksize is the same as for the pool
       Return:
           unpool:    unpooling tensor
    """
    with tf.variable_scope(scope):
        input_shape = pool.get_shape().as_list()
        output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3])

        flat_input_size = np.prod(input_shape)
        flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

        pool_ = tf.reshape(pool, [flat_input_size])
        batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [flat_input_size, 1])
        ind_ = tf.reshape(ind, [flat_input_size, 1])
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
        ret = tf.reshape(ret, output_shape)
        return ret
ThomasWollmann commented 7 years ago

I've adapted chahld's version to handle an unknown input tensor shape.

def unpool(pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):
    """
       Unpooling layer after max_pool_with_argmax.
       Args:
           pool:   max pooled output tensor
           ind:      argmax indices
           ksize:     ksize is the same as for the pool
       Return:
           unpool:    unpooling tensor
    """
    with tf.variable_scope(scope):
        input_shape =  tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype), 
                                          shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))
        return ret
MilesZhao commented 7 years ago

Do we have an alternative for MaxPoolWithArgmax on CPUs?

isaacgerg commented 7 years ago

I am also looking to implement SWWAE. It looks like a lot of good work has been done to address the unpooling issues. Does anyone have a small example showing it all put together?

yselivonchyk commented 7 years ago

@isaacgerg, I have the following example: https://github.com/yselivonchyk/Tensorflow_WhatWhereAutoencoder

It might use a slightly older version of the code, but it illustrates the pipeline. Note that it may differ from the official TF API released with the latest version.

isaacgerg commented 7 years ago

@yselivonchyk Quick question. Figure 1b of the paper shows two "where"s going to the decoder. It looks like you only have one. Is this correct?

yselivonchyk commented 7 years ago

@isaacgerg Correct. I reproduced the experiment there, which utilizes only a single layer.

If you want, you can take a look at https://github.com/yselivonchyk/TensorFlow_DCIGN/blob/master/model_interpreter.py, which can build a model, including SWWAE, using syntax like in the original paper: '(16)5c-(32)3c-Xp'

isaacgerg commented 7 years ago

@yselivonchyk Thanks. By inverse graphics, are you referring to the phrase Hinton uses to describe what he believes the brain does?

yselivonchyk commented 7 years ago

@isaacgerg, I never thought of that. The original intention was to reproduce https://arxiv.org/abs/1503.03167, hence the name. But it all adds up, since the authors worked together with Hinton. In the paper they feed in a sequence of images showing a single transformation and try to isolate the information about the transformation in the encoding space.

The model_interpreter.py part, though, is supposed to be independent of that concept.

isaacgerg commented 7 years ago

@yselivonchyk Yes, I believe this is similar to Hinton's transforming autoencoders.

MilesZhao commented 7 years ago

@isaacgerg I think we can use tf.gradients to realize this on CPUs. First, we take the gradients of the max-pooled results with respect to the feature maps, which locates the maximum values. Second, given those locations, we can do the up-pooling, i.e. reconstruct something. This is a little tricky. In our case we only have a maximum value for each feature map, so the "where it is" information is straightforward and we can use it to reconstruct a partial image.
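
A rough sketch of that idea (my own interpretation, untested): the gradient of max_pool with respect to its input is 1 at each max position and 0 elsewhere, which gives the "where" mask without max_pool_with_argmax, so it also runs on CPU:

import tensorflow as tf

x = tf.placeholder(tf.float32, [4, 32, 32, 8])
pooled = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# backpropagating ones through max_pool leaves 1s exactly at the max locations
where = tf.gradients(pooled, x, grad_ys=tf.ones_like(pooled))[0]

# nearest-neighbour upsample of the pooled values, masked down to the max positions
upsampled = tf.image.resize_nearest_neighbor(pooled, tf.shape(x)[1:3])
unpooled = upsampled * where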

isaacgerg commented 7 years ago

@MilesZhao This makes sense. I am trying to translate this to tf code.

MilesZhao commented 7 years ago

@isaacgerg I once emailed the authors of SWWAE. They implemented it in Torch. But I believe there is an example in Keras.

isaacgerg commented 7 years ago

@MilesZhao The Keras example will only run on Theano.

isaacgerg commented 7 years ago

I was able to get the Keras example to run in TensorFlow 1.2.1 with minimal changes.

manglav commented 7 years ago

@isaacgerg can you share your tensorflow example?

isaacgerg commented 7 years ago

@manglav https://github.com/isaacgerg/keras_odds_and_ends

Please message me if you find errors.

teramototoya commented 7 years ago

Hi @ThomasWollmann, I have a problem. Please see tensorflow/tensorflow#8102: scatter_nd has a duplicate-index problem. I think this does not work out for the Zeiler unpooling layer, but I want to use an unknown input tensor shape. Is there a way to solve this problem? Thanks!
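
For reference, a tiny sketch (my own example) of the behaviour that issue describes: tf.scatter_nd sums updates that land on the same index rather than keeping a single one, which is why duplicate argmax indices can corrupt the unpooled output:

import tensorflow as tf

indices = tf.constant([[1], [1], [3]])       # two updates target position 1
updates = tf.constant([3.0, 5.0, 7.0])
out = tf.scatter_nd(indices, updates, shape=[4])

with tf.Session() as sess:
    print(sess.run(out))                     # [0. 8. 0. 7.] -- duplicates are summed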

ThomasWollmann commented 7 years ago

@teramototoya I noticed this issue in my experiments as well. However, I don't have a solution yet.

ceteke commented 7 years ago

@ThomasWollmann Hey, if I'm not wrong, your code only works for a stride size of 1, right?

rayanelleuch commented 6 years ago

A small improvement on ThomasWollmann's code that adds the known static shape to the output tensor (also removed the tf.stack calls that were not needed). Useful when using tf.contrib.layers.conv2d.

def unpool(pool, ind, ksize=[1, 2, 2, 1], scope='unpool'):
    """
       Unpooling layer after max_pool_with_argmax.
       Args:
           pool:   max pooled output tensor
           ind:      argmax indices
           ksize:     ksize is the same as for the pool
       Return:
           unpool:    unpooling tensor
    """
    with tf.variable_scope(scope):
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.reduce_prod(input_shape)
        flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

        pool_ = tf.reshape(pool, [flat_input_size])
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype), 
                                          shape=[input_shape[0], 1, 1, 1])
        b = tf.ones_like(ind) * batch_range
        b1 = tf.reshape(b, [flat_input_size, 1])
        ind_ = tf.reshape(ind, [flat_input_size, 1])
        ind_ = tf.concat([b1, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, output_shape)

        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
        return ret
tspthomas commented 6 years ago

Hello,

So the bottom line is: some of the operations are not implemented on GPU, which causes unpool to be slow in TensorFlow. Is that correct?

I've tried to run some benchmarks using TensorFlow's default benchmarks, and for the previous version (from May 13) the top offender is FloorDiv. I switched to the version above (@rayanelleuch's), and the problem is now with the ConcatV2 op.

Pepslee commented 6 years ago

A new implementation: tf.one_hot has a GPU implementation, but I have checked only the forward computation and I'm not sure the backward gradient is implemented for this operation.

def unpool(pool, ind, ksize=(1, 2, 2, 1), scope='unpool'):
    """
       Unpooling layer after max_pool_with_argmax.
       Args:
           pool:   max pooled output tensor
           ind:      argmax indices (produced by tf.nn.max_pool_with_argmax)
           ksize:     ksize is the same as for the pool
       Return:
           unpooled:    unpooling tensor
    """
    with tf.variable_scope(scope):
        pooled_shape = pool.get_shape().as_list()

        flatten_ind = tf.reshape(ind, (pooled_shape[0], pooled_shape[1] * pooled_shape[2] * pooled_shape[3]))
        # convert the sparse indices to a dense one-hot mask
        one_hot_ind = tf.one_hot(flatten_ind,  pooled_shape[1] * ksize[1] * pooled_shape[2] * ksize[2] * pooled_shape[3], on_value=1., off_value=0., axis=-1)
        one_hot_ind = tf.reduce_sum(one_hot_ind, axis=1)
        one_like_mask = tf.reshape(one_hot_ind, (pooled_shape[0], pooled_shape[1] * ksize[1], pooled_shape[2] * ksize[2], pooled_shape[3]))
        # resize input array to the output size by nearest neighbor
        img = tf.image.resize_nearest_neighbor(pool, [pooled_shape[1] * ksize[1], pooled_shape[2] * ksize[2]])
        unpooled = tf.multiply(img, tf.cast(one_like_mask, img.dtype))
        return unpooled