Kotlin / kotlinx.serialization

Kotlin multiplatform / multi-format serialization
Apache License 2.0

[help/feature] Streaming multi collection size serializer #2694

Open Chuckame opened 1 month ago

Chuckame commented 1 month ago

Sorry for the bad title, it's quite difficult to sum up 😞

I need to implement array serialization for Avro, but it works differently from the usual encodings.

A collection (arrays & maps) is serialized as blocks, where each block starts with the number of items in that block (an integer). When a count is 0, the collection is finished. Here is a more visual explanation:

<1st block items count> | ... items ... | <n block items count> | ... items ... | 0

So just one block would be serialized like this:

<items count> | ... items ... | 0

Encoding is not an issue as we can make chunks quite easily.
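To make the layout concrete, here is a minimal sketch of the block framing. It assumes, for simplicity, that counts and items are plain `Long` values in a flat list rather than real Avro bytes; `encodeBlocks`, `decodeBlocks`, and `blockSize` are hypothetical names and not part of any library.

```kotlin
// Simplified model: the "wire" is a flat list of Longs instead of real Avro bytes.

/** Writes items as blocks of at most [blockSize] items, then the terminating 0 count. */
fun encodeBlocks(items: List<Long>, blockSize: Int): List<Long> {
    require(blockSize > 0) { "blockSize must be positive" }
    val wire = mutableListOf<Long>()
    for (chunk in items.chunked(blockSize)) {
        wire.add(chunk.size.toLong()) // block item count
        wire.addAll(chunk)            // block items
    }
    wire.add(0L) // an empty block marks the end of the collection
    return wire
}

/** Reads blocks until a 0 count is found and returns all items in order. */
fun decodeBlocks(wire: List<Long>): List<Long> {
    val input = wire.iterator()
    val items = mutableListOf<Long>()
    while (true) {
        val count = input.next().toInt()
        if (count == 0) break
        repeat(count) { items.add(input.next()) }
    }
    return items
}
```

The decoder cannot know the total size up front: it only learns that the collection is finished when it reads the 0 count.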

But decoding is harder:

I also tried to change the behavior inside decodeSerializableValue, but T.collectionSize() is not accessible because it is protected. All the possible implementations of AbstractCollectionSerializer are also internal, so I cannot get the real type (like HashMap or ArrayList) in order to get the collection size.

Here is the "wanted" code:

    @OptIn(InternalSerializationApi::class)
    override fun <T> decodeSerializableValue(deserializer: DeserializationStrategy<T>): T {
        if (deserializer is AbstractCollectionSerializer<*, T, *>) {
            var result: T = deserializer.merge(this, null)
            with(deserializer) {
                if (result.collectionSize() > 0) {
                    var prevSize = result.collectionSize()
                    while (true) {
                        result = deserializer.merge(this@AbstractAvroDirectDecoder, result)
                        val newSize = result.collectionSize()
                        if (prevSize == newSize) {
                            break
                        }
                        prevSize = newSize
                    }
                }
            }
            return result
        }

        return super<AbstractDecoder>.decodeSerializableValue(deserializer)
    }

Currently, I check the type of result to get its size properly, but this becomes hard to maintain:

    fun interface SizeGetter<T> {
        fun T.collectionSize(): Int
    }

    private fun <T> T.collectionSizeGetter(): SizeGetter<T> {
        return when (this) {
            is Collection<*> -> SizeGetter { size }
            is Map<*, *> -> SizeGetter { size }
            is Array<*> -> SizeGetter { size }
            is BooleanArray -> SizeGetter { size }
            is ByteArray -> SizeGetter { size }
            is ShortArray -> SizeGetter { size }
            is IntArray -> SizeGetter { size }
            is LongArray -> SizeGetter { size }
            is FloatArray -> SizeGetter { size }
            is DoubleArray -> SizeGetter { size }
            is CharArray -> SizeGetter { size }
            else -> throw SerializationException("Unsupported collection type: ${this?.let { it::class }}")
        }
    }

Proposal / Ideas

pdvrieze commented 1 month ago

@Chuckame The way you would normally implement something like this in a format would be to use a specialised decoder for the collection. This decoder would then record the item counts for the blocks (and special case of the empty block). You can record these counts when the collection serializer requests this information from the format.

Chuckame commented 1 month ago

Sorry @pdvrieze, I don't really understand. Do you have an example to provide?

Currently I'm using the provided code in my original post, where you can see that I'm overriding decodeSerializableValue of the Decoder. I cannot see any other way to decode blocks as we need explicit calls of decodeXxElement from the serializer to put the decoded values inside the collection (Except maybe using a custom serializer as said in the original post, but I'm missing some internal APIs)

pdvrieze commented 1 month ago

@Chuckame Looking back at your original post, I guess that the main challenge you have is that:

  1. You can't know the actual full list size without reading all elements.
  2. You want to use readAll repeatedly
  3. This should work as a single value, not as repeated entries in a composite/structure.
  4. You don't want to just capture the element deserializer (there are hacks you could use to capture this) and deal with this manually
  5. This should work with different kinds of elements, significantly including primitives (that could be parsed efficiently)

The issues you encounter are:

  1. AbstractCollectionSerializer.merge is designed to be called a single time to read an entire collection (note that it will call beginStructure and endStructure to delineate this).
  2. AbstractCollectionSerializer.merge can build a list from multiple parts, but that is intended to be as part of a composite value (challenge 4).
  3. You don't want to parse the entire list "locally" first before providing it to the collection serializer

As to the solution, what you want to do is write the collection implementation of decodeSerializableValue to pretend it is actually using a composite deserializer that flattens a collection of collections (you kind of have this already). You also need a way to detect the end of this list. So what you do is to have a new decoder (all the boring bits left out):

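internal class ListSizeDecoder(val delegate: Decoder) : Decoder, CompositeDecoder {
    // Only the "interesting bits" are shown; most members are delegated to `delegate`
    // (or to `compositeDelegate` for the CompositeDecoder members).

    // Item count of the last block read; -1 until decodeCollectionSize has been called.
    var lastListSize = -1

    var compositeDelegate: CompositeDecoder? = null

    override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
        compositeDelegate = delegate.beginStructure(descriptor)
        // Return the wrapper rather than the delegate so that the serializer calls
        // the decodeCollectionSize below. In endStructure, reset compositeDelegate to null.
        return this
    }

    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int {
        lastListSize = checkNotNull(compositeDelegate) { "beginStructure was not called" }
            .decodeCollectionSize(descriptor)
        return lastListSize
    }
}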

Using this decoder as the first parameter when calling merge, you have now captured the block size. You can use this to determine whether to stop (the value was 0).
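The control flow can be sketched without the library. The following is a stand-alone model, not kotlinx.serialization code: `MiniDecoder`, `TokenDecoder`, `SizeRecordingDecoder`, `mergeOnce`, and `decodeAllBlocks` are hypothetical names that only mimic the roles of the decoder, the wrapper, and `merge`.

```kotlin
// Stand-alone model of the approach; these types only mimic the roles of
// Decoder/CompositeDecoder and AbstractCollectionSerializer.merge.

/** Hypothetical minimal decoder: reads a collection size, then elements. */
interface MiniDecoder {
    fun decodeCollectionSize(): Int
    fun decodeInt(): Int
}

/** Reads counts and items from a flat token list (stand-in for the wire format). */
class TokenDecoder(tokens: List<Int>) : MiniDecoder {
    private val input = tokens.iterator()
    override fun decodeCollectionSize(): Int = input.next()
    override fun decodeInt(): Int = input.next()
}

/** Wrapper that delegates everything and remembers the last size that was read. */
class SizeRecordingDecoder(private val delegate: MiniDecoder) : MiniDecoder {
    var lastListSize = -1
        private set

    override fun decodeCollectionSize(): Int =
        delegate.decodeCollectionSize().also { lastListSize = it }

    override fun decodeInt(): Int = delegate.decodeInt()
}

/** Plays the role of `merge`: reads exactly one block and appends it to [previous]. */
fun mergeOnce(decoder: MiniDecoder, previous: List<Int>?): List<Int> {
    val result = previous?.toMutableList() ?: mutableListOf()
    repeat(decoder.decodeCollectionSize()) { result.add(decoder.decodeInt()) }
    return result
}

/** Calls [mergeOnce] until the wrapper reports that an empty block was read. */
fun decodeAllBlocks(source: MiniDecoder): List<Int> {
    val decoder = SizeRecordingDecoder(source)
    var result = mergeOnce(decoder, null)
    while (decoder.lastListSize != 0) {
        result = mergeOnce(decoder, result)
    }
    return result
}
```

Because the stop condition comes from the recorded block size, the loop never needs to ask the resulting collection for its size.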

pdvrieze commented 1 month ago

For primitives you may want to have a special case (use the serialName of the collection element; this is where unique serialNames come in handy). In that case you may want to just bulk read into a pre-allocated array. This only works with the built-in serializers though, as custom serializers are allowed to do all kinds of weird stuff (a long wire value can actually be a dateTime).
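A bulk read for one primitive type could look roughly like this. This is only a sketch: it assumes the format has already decided (for example from the element's serialName) that the elements are plain longs, and it models the input as an `Iterator<Long>` that yields both counts and items; `readLongBlocks` is a hypothetical name.

```kotlin
/**
 * Reads every block of a long collection straight into a LongArray.
 * The iterator is a stand-in for the format's low-level reads: it yields
 * a block count, then that many items, and finally a 0 count.
 */
fun readLongBlocks(input: Iterator<Long>): LongArray {
    var result = LongArray(0)
    while (true) {
        val count = input.next().toInt()
        if (count == 0) return result
        val start = result.size
        result = result.copyOf(start + count) // grow once per block, not per item
        for (i in 0 until count) {
            result[start + i] = input.next()
        }
    }
}
```

The array is resized once per block, which is possible because each block announces its item count before its items.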

Chuckame commented 1 month ago

After reading multiple times, I think I did not get how in your example it will read multiple blocks :/

> You want to use readAll repeatedly

I cannot, as it is protected. Also, readAll does not call decodeCollectionSize, while merge does. I don't see how I can bypass this merge method.


By the way, after decompiling, I can see how a list property is deserialized, which looks like a very good entry point for this need:

      @NotNull
      public Clients deserialize(@NotNull Decoder decoder) {
         Intrinsics.checkNotNullParameter(decoder, "decoder");
         SerialDescriptor var2 = this.getDescriptor();
         boolean var3 = true;
         boolean var4 = false;
         int var5 = 0;
         List var6 = null;
         CompositeDecoder var7 = decoder.beginStructure(var2);
         KSerializer[] var8 = Clients.$childSerializers;
         if (var7.decodeSequentially()) {
            var6 = (List)var7.decodeSerializableElement(var2, 0, (DeserializationStrategy)var8[0], var6);
            var5 |= 1;
         } else {
            while(var3) {
               int var9 = var7.decodeElementIndex(var2);
               switch (var9) {
                  case -1:
                     var3 = false;
                     break;
                  case 0:
                     var6 = (List)var7.decodeSerializableElement(var2, 0, (DeserializationStrategy)var8[0], var6);
                     var5 |= 1;
                     break;
                  default:
                     throw new UnknownFieldException(var9);
               }
            }
         }

         var7.endStructure(var2);
         return new Clients(var5, var6, (SerializationConstructorMarker)null);
      }

pdvrieze commented 1 month ago

> After reading multiple times, I think I did not get how in your example it will read multiple blocks :/
>
> You want to use readAll repeatedly

Actually, for your example code, if you create the Decoder/CompositeDecoder I suggested (make sure you implement beginStructure and endStructure in a way that reflects the actual wire data) and pass it as the first parameter to merge, you should get the expected behaviour (you can reuse the decoder and read the state from it).

Chuckame commented 1 month ago

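So I'll need 2 implementations:

  • one from decodeSerializableValue that iterates multiple times until I read an empty array block
  • one wrapping the decoder to intercept the decodeCollectionSize to allow the first implementation in decodeSerializableValue to read the collection size

Is that what you meant? I'll try it.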

pdvrieze commented 1 month ago

> So I'll need 2 implementations:
>
>   • one from decodeSerializableValue that iterates multiple times until I read an empty array block
>   • one wrapping the decoder to intercept the decodeCollectionSize to allow the first implementation in decodeSerializableValue to read the collection size
>
> Is it what you meant ? I'll try it

Yes. You need to create a specific decoder for lists. Note also that this may work differently with beginStructure/endStructure, as you may have markers for regular structs that differ from what is used for lists. I don't know the specifics of your data format.