tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

Ability to convert Tensor to String representation #268

Open cowwoc opened 3 years ago

cowwoc commented 3 years ago

Per our discussion on Gitter, here is a possible implementation for converting Tensors to a String representation. It is still missing some important features, like collapsing long arrays using ellipses, but this can serve as a stepping stone. The functionality is meant to ease troubleshooting/debugging so performance should not be an issue.

import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;

import java.util.StringJoiner;

public final class Tensors
{
    private final Session session;

    /**
     * @param session the session used by all operations
     */
    public Tensors(Session session)
    {
        this.session = session;
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TFloat64 tensor)
    {
        Shape shape = tensor.shape();
        DoubleDataBuffer doubles = tensor.asRawTensor().data().asDoubles();
        return toString(doubles, shape, 0, 0, tensor.rank()).text;
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TFloat32 tensor)
    {
        Shape shape = tensor.shape();
        FloatDataBuffer floats = tensor.asRawTensor().data().asFloats();
        return toString(floats, shape, 0, 0, tensor.rank()).text;
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TFloat16 tensor)
    {
        Shape shape = tensor.shape();
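        // The raw bytes hold 16-bit floats, so asFloats() would misread them as 32-bit values.
        // Copy the decoded values into a float buffer instead (assumes the ndarray version in
        // use exposes read(FloatDataBuffer)).
        FloatDataBuffer floats = org.tensorflow.ndarray.buffer.DataBuffers.ofFloats(shape.size());
        tensor.read(floats);
        return toString(floats, shape, 0, 0, tensor.rank()).text;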
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TInt64 tensor)
    {
        Shape shape = tensor.shape();
        LongDataBuffer longs = tensor.asRawTensor().data().asLongs();
        return toString(longs, shape, 0, 0, tensor.rank()).text;
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TInt32 tensor)
    {
        Shape shape = tensor.shape();
        IntDataBuffer ints = tensor.asRawTensor().data().asInts();
        return toString(ints, shape, 0, 0, tensor.rank()).text;
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TUint8 tensor)
    {
        Shape shape = tensor.shape();
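        // Each element is a single byte; asShorts() would fuse pairs of bytes into one value.
        // Note: elements print as signed bytes, so values above 127 appear negative.
        org.tensorflow.ndarray.buffer.ByteDataBuffer bytes = tensor.asRawTensor().data();
        return toString(bytes, shape, 0, 0, tensor.rank()).text;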
    }

    /**
     * @param data      the data
     * @param shape     the shape of the tensor
     * @param index     the index of the tensor element to start at
     * @param dimension the current dimension
     * @param rank      the number of dimensions of the tensor
     * @return the String representation of the {@code dimension}
     */
    private ToStringResponse toString(DataBuffer<?> data, Shape shape, int index, int dimension, int rank)
    {
        // A scalar (rank 0) has no dimension to iterate over: print its single element.
        if (rank == 0)
            return new ToStringResponse(String.valueOf(data.getObject(index)), 1);
        int numElements = 0;
        StringJoiner joiner;
        // Iterate over the size of the current dimension, not of the innermost one.
        long size = shape.size(dimension);
        if (dimension < rank - 1)
        {
            // Not yet at the innermost dimension: recurse into each sub-tensor.
            joiner = new StringJoiner(",\n", "\t".repeat(dimension) + "[\n", "\n" + "\t".repeat(dimension) + "]");
            for (long i = 0; i < size; ++i)
            {
                ToStringResponse response = toString(data, shape, index, dimension + 1, rank);
                joiner.add(response.text);
                numElements += response.numElements;
                index += response.numElements;
            }
        }
        else
        {
            // Innermost dimension: print its elements on a single line.
            joiner = new StringJoiner(",", "\t".repeat(dimension) + "[", "]");
            for (long i = 0; i < size; ++i)
            {
                joiner.add(String.valueOf(data.getObject(index)));
                ++numElements;
                ++index;
            }
        }
        return new ToStringResponse(joiner.toString(), numElements);
    }

    /**
     * @param text        the string representation of a tensor dimension
     * @param numElements the number of elements contained in {@code text}
     */
    private record ToStringResponse(String text, int numElements)
    {
    }
}
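
The missing "collapse long arrays using ellipses" feature mentioned above could look roughly like the sketch below. It is a standalone illustration, not part of TensorFlow Java: the class name `TensorFormatSketch`, the `edgeItems` parameter (borrowed from the idea behind NumPy's print options), and the use of a plain row-major `double[]` plus an `int[]` shape instead of a `DataBuffer` and `Shape` are all assumptions made to keep the example self-contained. It also prints everything on one line rather than using the tab and newline layout of the code above.

```java
import java.util.StringJoiner;

/** Hypothetical helper, not part of TensorFlow Java: formats a flat, row-major array as a nested list. */
final class TensorFormatSketch
{
    private TensorFormatSketch()
    {
    }

    /**
     * @param data      the elements in row-major order
     * @param shape     the size of each dimension
     * @param edgeItems the number of entries kept at the start and end of a collapsed dimension
     * @return the String representation, with long dimensions collapsed to "..."
     */
    static String format(double[] data, int[] shape, int edgeItems)
    {
        // A scalar has no dimensions: print its single element.
        if (shape.length == 0)
            return String.valueOf(data[0]);
        return format(data, shape, 0, 0, edgeItems);
    }

    private static String format(double[] data, int[] shape, int offset, int dimension, int edgeItems)
    {
        int size = shape[dimension];
        // Number of flat elements covered by one entry of this dimension.
        int stride = 1;
        for (int d = dimension + 1; d < shape.length; ++d)
            stride *= shape[d];
        boolean innermost = dimension == shape.length - 1;
        boolean collapse = size > 2 * edgeItems;

        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        for (int i = 0; i < size; ++i)
        {
            if (collapse && i == edgeItems)
            {
                // Skip the middle entries; the loop increment lands on the first tail entry.
                joiner.add("...");
                i = size - edgeItems - 1;
                continue;
            }
            int entryOffset = offset + i * stride;
            if (innermost)
                joiner.add(String.valueOf(data[entryOffset]));
            else
                joiner.add(format(data, shape, entryOffset, dimension + 1, edgeItems));
        }
        return joiner.toString();
    }
}
```

For example, `TensorFormatSketch.format(new double[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, new int[]{10}, 2)` yields `[0.0, 1.0, ..., 8.0, 9.0]`. Computing each entry's offset from the strides, rather than counting elements as they are printed, is what lets the skipped middle entries be ignored without visiting them.
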
rnett commented 3 years ago

Looks good, are you planning on making a PR for this?

If so, some initial comments:

cowwoc commented 3 years ago

@rnett Good suggestions. I'll try to formulate a PR.

One question, though: since this method is meant strictly for debugging, couldn't we implement it for non-eager sessions as well? We could spin up a graph runner, evaluate the operand, and return its String representation.

rnett commented 3 years ago

You could, but if that tensor depends on anything non-constant (i.e. placeholders or variables), you won't be able to get it, since the session has no way of knowing about those inputs. And I would think most things you want to debug would have dependencies like that. Plus I'm not sure sessions support adding things to the graph after the session is created, and you'd have to re-run the whole graph each time you called asString.
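
The placeholder problem can be illustrated with a toy expression graph. This is only a sketch of the idea and deliberately does not use the TensorFlow Java API: `GraphEvalSketch`, `Node`, `Constant`, `Placeholder`, `Add`, and this `asString` are hypothetical names invented for the example. It shows that a node built purely from constants can be evaluated with no inputs, while anything that depends on a placeholder cannot.

```java
import java.util.Map;

/** Toy model of a dataflow graph; hypothetical names, not the TensorFlow Java API. */
final class GraphEvalSketch
{
    private GraphEvalSketch()
    {
    }

    interface Node
    {
        /** Evaluates this node given the values fed for placeholders, keyed by name. */
        double eval(Map<String, Double> feeds);
    }

    record Constant(double value) implements Node
    {
        public double eval(Map<String, Double> feeds)
        {
            return value;
        }
    }

    record Placeholder(String name) implements Node
    {
        public double eval(Map<String, Double> feeds)
        {
            Double fed = feeds.get(name);
            if (fed == null)
                throw new IllegalStateException("No value fed for placeholder '" + name + "'");
            return fed;
        }
    }

    record Add(Node left, Node right) implements Node
    {
        public double eval(Map<String, Double> feeds)
        {
            return left.eval(feeds) + right.eval(feeds);
        }
    }

    /** What a feed-less asString would have to do: evaluate the node with no inputs at all. */
    static String asString(Node node)
    {
        return String.valueOf(node.eval(Map.of()));
    }
}
```

Here `asString(new Add(new Constant(1), new Constant(2)))` returns `3.0`, but `asString(new Add(new Constant(1), new Placeholder("x")))` throws, because nothing tells the evaluator what `x` should be. A real graph-mode implementation would face the same gap for placeholders and variables.
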

Once we finish functions and eager gradients, most debugging should be done in eager mode anyway. You'd almost always use functions instead of graphs, and there would be a global "execute functions in eager mode" setting like in Python. It's not necessarily impossible to have asString in graph mode, but it's not easy and won't ever fit very well, and since we have this coming I don't think it's worth it. Feel free to come up with an implementation and make a PR, but maybe PR just the eager version first.