sandialabs / pyttb

Python Tensor Toolbox
https://pyttb.readthedocs.io
BSD 2-Clause "Simplified" License

Pretty print function for dense tensors #350

Open · tgkolda opened this issue 5 days ago

tgkolda commented 5 days ago

Would it be possible to add a pretty-print function for dense tensors? Here is what I have in mind:

import numpy as np
import pyttb as ttb


def pretty_print_tensor(X, fmt="10.4f", name="Slice"):
    """Print a dense pyttb tensor as a sequence of 2-D slices X(:, :, k, ...)."""
    if not isinstance(X, ttb.tensor):
        raise TypeError("Input must be a pyttb tensor")

    # Get the shape of the tensor
    shape = X.shape
    shape_str = " x ".join(map(str, shape))

    if name == "Slice":
        print(f"Tensor is of shape {shape_str}")
    else:
        print(f"{name} is a tensor of shape {shape_str}")

    # Iterate over all slices in Fortran order (first trailing index fastest).
    # np.ndindex varies its last index fastest, so reverse the trailing
    # dimensions here and reverse each index tuple back inside the loop.
    for index in np.ndindex(shape[2:][::-1]):
        index = index[::-1]
        # Keep all rows and columns; fix the remaining indices
        slice_indices = (slice(None), slice(None)) + index
        array = X.data[slice_indices]  # 2-D numpy array for this slice
        # Build the label so a 2-way tensor prints as "X(:, :)", not "X(:, :, )"
        label = ", ".join([":", ":"] + [str(i) for i in index])
        print(f"{name}({label}) =")
        for row in array:
            print(" ".join(f"{val:{fmt}}" for val in row))

So, for a tensor with integer entries (the "d" format code rejects floats), pretty_print_tensor(X, fmt="2d", name="X") produces output like this:

X is a tensor of shape 3 x 3 x 2
X(:, :, 0) =
 3  9  1
 8  2  1
 4  3  9
X(:, :, 1) =
 6  9  5
 5  6  4
 1  4  1

Or, for a different tensor, pretty_print_tensor(X, fmt="5.1f") produces:

Tensor is of shape 3 x 4 x 3 x 2
Slice(:, :, 0, 0) =
  1.0   7.0   5.0   5.0
  8.0   9.0   1.0   7.0
  4.0   5.0   3.0   8.0
Slice(:, :, 1, 0) =
  4.0   9.0   9.0   9.0
  1.0   2.0   1.0   3.0
  3.0   5.0   6.0   5.0
Slice(:, :, 2, 0) =
  9.0   7.0   2.0   5.0
  2.0   7.0   5.0   4.0
  5.0   5.0   4.0   8.0
Slice(:, :, 0, 1) =
  7.0   3.0   6.0   7.0
  3.0   5.0   4.0   4.0
  6.0   4.0   5.0   9.0
Slice(:, :, 1, 1) =
  2.0   4.0   4.0   7.0
  7.0   6.0   1.0   5.0
  4.0   5.0   1.0   7.0
Slice(:, :, 2, 1) =
  9.0   3.0   9.0   5.0
  7.0   6.0   9.0   8.0
  1.0   3.0   8.0   2.0
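
The slice order in the second example, (0, 0), (1, 0), (2, 0), (0, 1), and so on, comes from reversing the trailing dimensions before calling np.ndindex, which (like itertools.product) varies its last index fastest. A stdlib-only sketch of the same trick follows; fortran_order_indices is a hypothetical helper name used only for illustration, not part of pyttb.

```python
import itertools
from typing import List, Sequence, Tuple


def fortran_order_indices(dims: Sequence[int]) -> List[Tuple[int, ...]]:
    """Hypothetical helper: index tuples for `dims`, first index varying fastest."""
    # itertools.product varies its last argument fastest (C order), so iterate
    # over the reversed dimensions and flip each tuple back to the original order.
    return [idx[::-1] for idx in itertools.product(*(range(n) for n in reversed(dims)))]


# Trailing dimensions of the 3 x 4 x 3 x 2 example are (3, 2)
print(fortran_order_indices((3, 2)))
# → [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
```

For a 2-way tensor the trailing dimensions are empty and the loop runs exactly once with an empty index tuple, which is why the slice label needs to handle that case.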

tgkolda commented 5 days ago

Maybe it should be a method of the tensor class, in which case the call would be X.pretty_print(...).
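
To make the method-style call concrete without depending on pyttb internals, here is a minimal sketch built on a hypothetical stand-in class, DenseTensor, whose entries are stored as a flat list in column-major (Fortran) order, the convention MATLAB uses. DenseTensor and pretty_format are illustrative names, not pyttb API. pretty_format builds the text as a string, which keeps it testable and reusable (for example from __str__), and pretty_print simply prints it. The usage lines at the end reproduce the 3 x 3 x 2 example from the issue.

```python
import itertools
from typing import List, Sequence


class DenseTensor:  # hypothetical stand-in for ttb.tensor
    """Entries stored as a flat list in column-major order."""

    def __init__(self, data: Sequence[float], shape: Sequence[int]):
        self.shape = tuple(shape)
        size = 1
        for n in self.shape:
            size *= n
        if len(data) != size:
            raise ValueError("data length does not match shape")
        self.data = list(data)

    def _strides(self) -> List[int]:
        # Column-major strides: the first index varies fastest
        strides, step = [], 1
        for n in self.shape:
            strides.append(step)
            step *= n
        return strides

    def pretty_format(self, fmt: str = "10.4f", name: str = "Slice") -> str:
        """Build the same text that pretty_print_tensor prints."""
        if len(self.shape) < 2:
            raise ValueError("pretty printing needs at least two dimensions")
        shape_str = " x ".join(map(str, self.shape))
        if name == "Slice":
            lines = [f"Tensor is of shape {shape_str}"]
        else:
            lines = [f"{name} is a tensor of shape {shape_str}"]
        strides = self._strides()
        # Walk the trailing indices with the first trailing index varying fastest
        trailing = self.shape[2:]
        for rev in itertools.product(*(range(n) for n in reversed(trailing))):
            index = rev[::-1]
            offset = sum(i * s for i, s in zip(index, strides[2:]))
            label = ", ".join([":", ":"] + [str(i) for i in index])
            lines.append(f"{name}({label}) =")
            for r in range(self.shape[0]):
                row = (self.data[offset + r * strides[0] + c * strides[1]]
                       for c in range(self.shape[1]))
                lines.append(" ".join(f"{val:{fmt}}" for val in row))
        return "\n".join(lines)

    def pretty_print(self, fmt: str = "10.4f", name: str = "Slice") -> None:
        """Method form of the proposal: X.pretty_print(...)."""
        print(self.pretty_format(fmt=fmt, name=name))


# Columns listed one after another (column-major) for the 3 x 3 x 2 example
X = DenseTensor([3, 8, 4, 9, 2, 3, 1, 1, 9, 6, 5, 1, 9, 6, 4, 5, 4, 1], (3, 3, 2))
X.pretty_print(fmt="2d", name="X")
```

Splitting formatting from printing is only one possible design; a pyttb method would more likely index X.data directly, as pretty_print_tensor does.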