google / flatbuffers

FlatBuffers: Memory Efficient Serialization Library
https://flatbuffers.dev/
Apache License 2.0

Faster python serialization #8320

Closed mhs4670go closed 1 month ago

mhs4670go commented 1 month ago

Hello. I'm trying to pack the buffers below, which hold large data, but packing takes too long.

My flatbuffers version is 24.3.25, and I generate the Python code with the flatc --python --gen-object-api # .. command.

https://github.com/tensorflow/tensorflow/blob/9076ac1496dfbf228220bf728385db6c96447fdf/tensorflow/lite/schema/schema.fbs#L1573

table Buffer {
  data:[ubyte] (force_align: 16);
}

I tried two approaches. First, I assigned the bytes (as a list) directly to buffer.data, but the generated Pack method then spent a very long time in the builder.PrependUint8(self.data[i]) loop shown below. Second, I assigned a NumPy array of bytes instead. Packing was faster that way, but converting the bytes into a NumPy array with dtype uint8 also takes quite a long time. Is there a faster way to serialize this?


### generated code
# BufferT
    def Pack(self, builder):
        if self.data is not None:
            if np is not None and type(self.data) is np.ndarray:
                data = builder.CreateNumpyVector(self.data)
            else:
                StartDataVector(builder, len(self.data))
                for i in reversed(range(len(self.data))):
                    builder.PrependUint8(self.data[i]) ## took so long
                data = builder.EndVector()
        Start(builder)
        if self.data is not None:
            AddData(builder, data)
        AddOffset(builder, self.offset)
        AddSize(builder, self.size)
        buffer = End(builder)
        return buffer

### my code
assert isinstance(data, np.ndarray)
buffer = tflite.Buffer.BufferT()
# 1. just with bytes
# buffer.data = list(data.flatten().tobytes())
# 2. from np array
buffer.data = np.array(list(data.flatten().tobytes()), dtype=np.uint8)
buffer.Pack(builder)

mhs4670go commented 1 month ago

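I found that ndarray.view does the trick :) It reinterprets the array's existing memory as uint8 instead of converting it element by element.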

buffer.data = data.flatten().view(np.uint8)
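
For anyone landing here later, a minimal sketch of why the view helps, using only NumPy. The array `data` below is a hypothetical stand-in for a weight tensor (its shape and dtype are assumptions, not taken from the issue). It checks that the view yields the same bytes as the slow list round-trip, and that the result is a plain 1-D np.ndarray, which is what the `type(self.data) is np.ndarray` check in the generated Pack looks for before it calls CreateNumpyVector.

```python
import numpy as np

# Hypothetical stand-in for a weight tensor; shape and dtype are assumptions.
data = np.arange(12, dtype=np.float32).reshape(3, 4)

# Slow path from the question: bytes -> list of Python ints -> uint8 array.
slow = np.array(list(data.flatten().tobytes()), dtype=np.uint8)

# Fast path: flatten() returns a contiguous 1-D copy, and view() reinterprets
# that memory as uint8 without creating a Python object per byte.
fast = data.flatten().view(np.uint8)

# Another option that avoids the Python-level list: wrap the raw bytes directly.
# The resulting array is read-only because it wraps an immutable bytes object.
alt = np.frombuffer(data.tobytes(), dtype=np.uint8)

# All three hold identical bytes; only the conversion cost differs.
same_bytes = np.array_equal(slow, fast) and np.array_equal(fast, alt)

# The generated Pack only takes the NumPy branch for an exact np.ndarray.
is_plain_1d_array = type(fast) is np.ndarray and fast.ndim == 1
```

The element order matches the array's in-memory byte order, so on a little-endian machine this is the same layout that tobytes() produces.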