```python
import tensorflow as tf


def pack_into_tensor(array, axis):
    """
    Packs a given TensorArray into a tensor along a given axis.

    Parameters
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis to pack the array along

    Returns
    -------
    Tensor
        the packed tensor
    """
    # pack() stacks the array's elements along a new leading axis 0
    packed_tensor = array.pack()
    shape = packed_tensor.get_shape()
    rank = len(shape)
    # intended to move axis 0 to position `axis` (this is the line at issue)
    dim_permutation = [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank))
    correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)
    return correct_shape_tensor
```
Imagine I want to stack a `TensorArray` with 6 elements of shape [3,4,5] into a tensor of shape [3,4,5,6], so `axis` is 3. After `array.pack()`, the tensor has shape [6,3,4,5]. `dim_permutation` is [3] + [1,2] + [0] + [] = [3,1,2,0], so after the transpose the output shape is [5,3,4,6], which is wrong.
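The shape arithmetic of this example can be reproduced without TensorFlow. This sketch assumes NumPy as a stand-in: `np.stack(..., axis=0)` plays the role of `array.pack()`, and `np.transpose` uses the same permutation convention as `tf.transpose` (output dimension `i` is input dimension `perm[i]`). `buggy_permutation` is a hypothetical helper that copies the formula from `pack_into_tensor`.

```python
import numpy as np


def buggy_permutation(axis, rank):
    # Same formula as in pack_into_tensor (list() so it also runs on Python 3).
    return [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank))


# Six elements of shape [3, 4, 5]; stacking on axis 0 mimics TensorArray.pack().
elements = [np.zeros((3, 4, 5)) for _ in range(6)]
packed = np.stack(elements, axis=0)  # shape (6, 3, 4, 5)

perm = buggy_permutation(axis=3, rank=packed.ndim)
result = np.transpose(packed, perm)

print(perm)          # [3, 1, 2, 0]
print(result.shape)  # (5, 3, 4, 6), not the intended (3, 4, 5, 6)
```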
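The correct formula should be `dim_permutation = list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))`.
With this formula, `dim_permutation` is [1,2,3] + [0] + [] = [1,2,3,0], and the transposed tensor has the intended shape [3,4,5,6].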
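The corrected permutation can be checked numerically. This sketch again assumes NumPy as a stand-in for TensorFlow (`np.stack(..., axis=0)` for `array.pack()`, `np.transpose` for `tf.transpose`, which share the same permutation convention) and compares the transposed result with stacking directly along `axis`, for every valid axis. `pack_permutation` is a hypothetical helper name.

```python
import numpy as np


def pack_permutation(axis, rank):
    # Move the packed dimension (index 0) to position `axis`;
    # dimensions 1..axis each shift one place to the left.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


# Random data so the check compares values, not just shapes.
rng = np.random.default_rng(0)
elements = [rng.standard_normal((3, 4, 5)) for _ in range(6)]
packed = np.stack(elements, axis=0)  # shape (6, 3, 4, 5)

for axis in range(packed.ndim):
    perm = pack_permutation(axis, packed.ndim)
    result = np.transpose(packed, perm)
    # Stacking directly along `axis` is the reference answer.
    assert np.array_equal(result, np.stack(elements, axis=axis))

print(pack_permutation(3, 4))                               # [1, 2, 3, 0]
print(np.transpose(packed, pack_permutation(3, 4)).shape)   # (3, 4, 5, 6)
```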
Similarly, the formula in `unpack_into_tensorarray` is wrong. The correct code should be:
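The body of `unpack_into_tensorarray` is not quoted in this issue, so the following is only a sketch under an assumption: that the function transposes `axis` to position 0 before unpacking along dimension 0. In that case the permutation it needs is the inverse of the packing permutation [1,2,3,0], namely [3,0,1,2] for `axis` = 3. NumPy stands in for TensorFlow (`np.transpose` follows the same permutation convention as `tf.transpose`), and `pack_permutation` / `unpack_permutation` are hypothetical helper names.

```python
import numpy as np


def pack_permutation(axis, rank):
    # Corrected packing permutation: moves dimension 0 to position `axis`.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


def unpack_permutation(axis, rank):
    # Inverse: moves dimension `axis` to position 0, keeping the others in order.
    return [axis] + list(range(axis)) + list(range(axis + 1, rank))


rng = np.random.default_rng(0)
tensor = rng.standard_normal((3, 4, 5, 6))

for axis in range(tensor.ndim):
    rank = tensor.ndim
    up = unpack_permutation(axis, rank)
    # argsort of a permutation is its inverse permutation.
    assert up == np.argsort(pack_permutation(axis, rank)).tolist()
    # After the transpose, index i along dim 0 is slice i along `axis`.
    moved = np.transpose(tensor, up)
    for i in range(tensor.shape[axis]):
        assert np.array_equal(moved[i], np.take(tensor, i, axis=axis))

print(unpack_permutation(3, 4))  # [3, 0, 1, 2]
```

Because the two permutations are inverses, unpacking a tensor and packing the pieces back along the same `axis` returns the original tensor.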
Please take a look. Thanks!