tensorflow / java

Java bindings for TensorFlow
Apache License 2.0
813 stars 200 forks source link

raw bytes from TString for tensor serialization #471

Closed lucaro closed 2 years ago

lucaro commented 2 years ago

I'm trying to write tfrecords from a Java application, but when I read back the String from the tensor serialization function and convert it to bytes in order to be written, I get differences in certain places. This has probably something to do with the character encoding in Java. I added some example code below. Is there another way to get to the raw bytes of a TString in order to get around this?


import com.google.protobuf.ByteString;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.hadoop.util.TFRecordWriter;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.io.SerializeTensor;
import org.tensorflow.proto.example.BytesList;
import org.tensorflow.proto.example.Example;
import org.tensorflow.proto.example.Feature;
import org.tensorflow.proto.example.Features;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TString;

import java.io.*;
import java.nio.charset.StandardCharsets;

public class Main {

    public static void main(String[] args) throws IOException {

        float[] vecFloat = new float[512];

        for (int i = 0; i < vecFloat.length; i++) {
            vecFloat[i] = (float) i;
        }

        TFloat32 tensor = TFloat32.tensorOf(Shape.of(1, 512), DataBuffers.of(vecFloat));

        TFloat32 empty = TFloat32.tensorOf(Shape.of(1, 512), DataBuffers.of(new float[512]));

        //expecting 2064 elements starting with 8, 1, 18, 9, 18, 2, 8, 1, 18, 3, 8, 128, 4, 34, 128, 16, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, ...
        //got       2064 elements starting with 8, 1, 18, 9, 18, 2, 8, 1, 18, 3, 8,  63, 4, 34,  63, 16, 0, 0, 0, 0, 0, 0,  63, 63, 0, 0, 0, 64, 0, ...
        //                                                                          ^^^         ^^^                        ^^^  **
        byte[] tensorBytes = serialize(tensor);

        //expecting 2064 elements starting with 8, 1, 18, 9, 18, 2, 8, 1, 18, 3, 8, 128, 4, 34, 128, 16, 0, 0, 0, 0, 0, 0, 0, ...
        //got       2064 elements starting with 8, 1, 18, 9, 18, 2, 8, 1, 18, 3, 8,  63, 4, 34,  63, 16, 0, 0, 0, 0, 0, 0, 0, ...
        //                                                                          ^^^         ^^^
        byte[] emptyBytes = serialize(empty);

        Example example = Example.newBuilder().setFeatures(
                Features.newBuilder()
                        .putFeature("name", feature("testRecord".getBytes()))
                        .putFeature("vector", feature(tensorBytes))
        ).build();

        DataOutputStream stream = new DataOutputStream(new FileOutputStream("test.tfrecord"));
        TFRecordWriter writer = new TFRecordWriter(stream);

        writer.write(example.toByteArray());

        stream.flush();
        stream.close();

    }

    private static byte[] serialize(TFloat32 tensor) {

        try (
                ConcreteFunction fun = ConcreteFunction.create(Main::serializeTensor);
                TString serialized = (TString) fun.call(tensor);
        ) {
            DataBuffer<String> buf = DataBuffers.ofObjects(String.class, 1);
            serialized.read(buf);
            serialized.close();
            return buf.getObject(0).getBytes(StandardCharsets.US_ASCII);
        }

    }

    private static Signature serializeTensor(Ops tf) {
        Placeholder<TFloat32> x = tf.placeholder(TFloat32.class);
        SerializeTensor op = tf.io.serializeTensor(x);
        return Signature.builder().input("in", x).output("out", op).build();
    }

    private static BytesList bytesList(byte[] bytes) {
        return BytesList.newBuilder().addValue(ByteString.copyFrom(bytes)).build();
    }

    private static Feature feature(byte[] bytes) {
        return Feature.newBuilder().setBytesList(bytesList(bytes)).build();
    }

}
lucaro commented 2 years ago

Never mind, it works using


private static byte[] serialize(TFloat32 tensor) {
        try (
                ConcreteFunction fun = ConcreteFunction.create(Main::serializeTensor);
                TString serialized = (TString) fun.call(tensor);
        ) {
            DataBuffer<byte[]> buf = DataBuffers.ofObjects(byte[].class, 1);
            serialized.asBytes().read(buf);
            serialized.close();
            return buf.getObject(0);
        }
    }