Mostafa-Samir / DNC-tensorflow

A TensorFlow implementation of DeepMind's Differential Neural Computers (DNC)
MIT License
581 stars, 164 forks

axis permutation wrong in utility.pack_into_tensor and utility.unpack_into_tensorarray #8

Open JeffOwOSun opened 7 years ago

JeffOwOSun commented 7 years ago
def pack_into_tensor(array, axis):
    """
    packs a given TensorArray into a tensor along a given axis
    Parameters:
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis to pack the array along
    Returns: Tensor
        the packed tensor
    """

    packed_tensor = array.pack()
    shape = packed_tensor.get_shape()
    rank = len(shape)

    dim_permutation = [axis] + range(1, axis) + [0]  + range(axis + 1, rank)
    correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)

    return correct_shape_tensor

Imagine I want to stack a TensorArray with 6 elements of shape [3,4,5] into a tensor of shape [3,4,5,6], so axis is 3. After array.pack(), the shape of the tensor is [6,3,4,5]. dim_permutation is [3] + [1,2] + [0] + [] = [3,1,2,0]. After the transpose, the output shape is [5,3,4,6] instead of the expected [3,4,5,6]: the formula swaps dimension 0 with dimension axis instead of moving dimension 0 to position axis.

The correct formula should be

dim_permutation = range(1, axis+1) + [0] + range(axis+1, rank)

With this formula, dim_permutation is [1,2,3] + [0] + [] = [1,2,3,0], and the transposed shape is the expected [3,4,5,6].
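
Both permutations can be checked without TensorFlow by applying them to a plain shape list, since tf.transpose puts input dimension perm[i] at output position i. Below is a minimal sketch of that check. The helper names (transposed_shape, pack_perm_current, pack_perm_proposed) are made up for illustration and are not part of the repo, and list(range(...)) is used so the snippet also runs on Python 3.

```python
def transposed_shape(shape, perm):
    """Shape after transposing a tensor of `shape` with `perm`
    (output dimension i is input dimension perm[i], as in tf.transpose)."""
    return [shape[p] for p in perm]


def pack_perm_current(axis, rank):
    # Formula quoted from utility.pack_into_tensor above.
    return [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank))


def pack_perm_proposed(axis, rank):
    # Proposed formula: move dimension 0 to position `axis`,
    # keeping the remaining dimensions in their original order.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


# Shape after array.pack(): 6 elements, each of shape [3, 4, 5].
packed_shape = [6, 3, 4, 5]
axis, rank = 3, len(packed_shape)

current = pack_perm_current(axis, rank)
proposed = pack_perm_proposed(axis, rank)

print(current, transposed_shape(packed_shape, current))    # [3, 1, 2, 0] [5, 3, 4, 6]
print(proposed, transposed_shape(packed_shape, proposed))  # [1, 2, 3, 0] [3, 4, 5, 6]
```

For axis = 1 the two formulas produce the same permutation, so the difference only shows up for other values of axis.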

The formula in unpack_into_tensorarray is wrong in the same way. The correct code should be

dim_permutation = [axis] + range(0, axis) + range(axis+1, rank)
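
For the same example, a tensor of shape [3,4,5,6] with axis = 3 should become [6,3,4,5] before it is unpacked along dimension 0. The sketch below checks that on a plain shape list, and also checks that the proposed unpack permutation followed by the proposed pack permutation is the identity. The helper names are hypothetical, not functions from the repo, and list(range(...)) is used for Python 3 compatibility.

```python
def transposed_shape(shape, perm):
    """Shape after transposing a tensor of `shape` with `perm`
    (output dimension i is input dimension perm[i], as in tf.transpose)."""
    return [shape[p] for p in perm]


def unpack_perm_proposed(axis, rank):
    # Bring dimension `axis` to the front, keep the others in order.
    return [axis] + list(range(0, axis)) + list(range(axis + 1, rank))


def pack_perm_proposed(axis, rank):
    # Move dimension 0 back to position `axis`.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


def compose(first, second):
    """Single permutation equivalent to transposing with `first`, then `second`."""
    return [first[p] for p in second]


shape = [3, 4, 5, 6]
axis, rank = 3, len(shape)

unpack_perm = unpack_perm_proposed(axis, rank)
print(unpack_perm, transposed_shape(shape, unpack_perm))  # [3, 0, 1, 2] [6, 3, 4, 5]

# Unpacking and then packing along the same axis should restore the layout.
round_trip = compose(unpack_perm, pack_perm_proposed(axis, rank))
print(round_trip)  # [0, 1, 2, 3]
```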

Please take a look. Thanks