plokhotnyuk / jsoniter-scala

Scala macros for compile-time generation of safe and ultra-fast JSON codecs + circe booster
MIT License

Support untagged ADTs #435

Open carymrobbins opened 4 years ago

carymrobbins commented 4 years ago

Is it possible to derive a codec which supports writing/reading untagged ADTs? Something like Aeson's UntaggedValue, of course with some caveat that the encodings must be disjoint. I am deserializing data from another source and don't have control over the JSON format.

I'm also open to other ideas, e.g. deriving my own codec by creating alternations of derived codecs. I'm not sure how possible this is. (I've found writing the codecs by hand to be pretty cumbersome, which is not really a problem so long as we come up with some way to make this work).

The ideal situation would be something like the following. Again, if this is currently achievable in some other way, I'd be fine with it.

sealed trait Foo
final case class Bar(a: String) extends Foo
final case class Baz(b: Int) extends Foo

implicit val fooCodec: JsonValueCodec[Foo] =
  JsonCodecMaker.make(CodecMakerConfig.withUntaggedAdt)

writeToString(Bar("yo"))
// {"a":"yo"}
writeToString(Baz(12))
// {"b":12}
readFromString[Foo](""" {"a":"yo"} """)
// Bar("yo")
readFromString[Foo](""" {"b":12} """)
// Baz(12)

plokhotnyuk commented 4 years ago

@carymrobbins

Hello, Cary! Thanks for reaching out!

Currently, it is achievable only with manually written custom codecs or 3rd-party derivation.

Let's start with your simplified example. The codecs for Bar and Baz are derived, but for Foo we can write a custom one:

implicit val barCodec: JsonValueCodec[Bar] = JsonCodecMaker.make(CodecMakerConfig)
implicit val bazCodec: JsonValueCodec[Baz] = JsonCodecMaker.make(CodecMakerConfig)
implicit val fooCodec: JsonValueCodec[Foo] = new JsonValueCodec[Foo] {
  override def decodeValue(in: JsonReader, default: Foo): Foo = {
    in.setMark()
    if (in.isNextToken('{')) {
      val l = in.readKeyAsCharBuf()
      if (in.isCharBufEqualsTo(l, "a")) {
        in.rollbackToMark()
        barCodec.decodeValue(in, barCodec.nullValue)
      } else if (in.isCharBufEqualsTo(l, "b")) {
        in.rollbackToMark()
        bazCodec.decodeValue(in, bazCodec.nullValue)
      } else in.unexpectedKeyError(l)
    } else in.readNullOrTokenError(default, '{')
  }
  override def encodeValue(x: Foo, out: JsonWriter): Unit = x match {
    case x: Bar => barCodec.encodeValue(x, out)
    case x: Baz => bazCodec.encodeValue(x, out)
    case null => out.writeNull()
  }
  override val nullValue: Foo = null
}

println(writeToString[Foo](Bar("yo")))
println(writeToString[Foo](Baz(12)))
println(readFromString[Foo](""" {"a":"yo"} """))
println(readFromString[Foo](""" {"b":12} """))

The output of this code should be:

{"a":"yo"}
{"b":12}
Bar(yo)
Baz(12)

This solution can be easily extended as long as all subtypes have unique field names.

If the key sets of some subtypes overlap in a simple way, you can parse all keys, accumulating bits (one per unique name of a required field) and skipping the paired values. Then, after reaching the end of the JSON object (the } character), roll back to the marked position and switch to parsing the detected type:

sealed trait Foo
final case class Bar(a: String, x: Option[String]) extends Foo
final case class Baz(y: Option[Int], b: Int) extends Foo
final case class Qux(a: Int, z: Seq[Double], b: String) extends Foo

implicit val barCodec: JsonValueCodec[Bar] = JsonCodecMaker.make(CodecMakerConfig)
implicit val bazCodec: JsonValueCodec[Baz] = JsonCodecMaker.make(CodecMakerConfig)
implicit val quxCodec: JsonValueCodec[Qux] = JsonCodecMaker.make(CodecMakerConfig)
implicit val fooCodec: JsonValueCodec[Foo] = new JsonValueCodec[Foo] {
  override def decodeValue(in: JsonReader, default: Foo): Foo = {
    in.setMark()
    if (in.isNextToken('{')) {
      var p0 = 3
      do {
        val l = in.readKeyAsCharBuf()
        if (in.isCharBufEqualsTo(l, "a")) {
          if ((p0 & 1) != 0) p0 ^= 1
          else in.duplicatedKeyError(l)
        } else if (in.isCharBufEqualsTo(l, "b")) {
          if ((p0 & 2) != 0) p0 ^= 2
          else in.duplicatedKeyError(l)
        }
        in.skip()
      } while (in.isNextToken(','))
      in.rollbackToMark()
      p0 match {
        case 0 => quxCodec.decodeValue(in, quxCodec.nullValue)
        case 1 => bazCodec.decodeValue(in, bazCodec.nullValue)
        case 2 => barCodec.decodeValue(in, barCodec.nullValue)
        case _ => in.decodeError("missing required field(s)")
      }
    } else in.readNullOrTokenError(default, '{')
  }
  override def encodeValue(x: Foo, out: JsonWriter): Unit = x match {
    case x: Bar => barCodec.encodeValue(x, out)
    case x: Baz => bazCodec.encodeValue(x, out)
    case x: Qux => quxCodec.encodeValue(x, out)
    case null => out.writeNull()
  }
  override val nullValue: Foo = null
}

println(writeToString[Foo](Bar("yo", None)))
println(writeToString[Foo](Baz(None, 12)))
println(writeToString[Foo](Qux(12, Seq(), "yo")))
println(writeToString[Foo](Bar("yo", Some("lo"))))
println(writeToString[Foo](Baz(Some(42), 12)))
println(writeToString[Foo](Qux(12, Seq(1.0, 2.0), "yo")))
println(readFromString[Foo](""" {"a":"yo"} """))
println(readFromString[Foo](""" {"b":12} """))
println(readFromString[Foo](""" {"a":12,"b":"yo"} """))
println(readFromString[Foo](""" {"a":"yo","x":"lo"} """))
println(readFromString[Foo](""" {"y":42,"b":12} """))
println(readFromString[Foo](""" {"a":12,"z":[1.0,2.0],"b":"yo"} """))

The expected output is:

{"a":"yo"}
{"b":12}
{"a":12,"b":"yo"}
{"a":"yo","x":"lo"}
{"y":42,"b":12}
{"a":12,"z":[1.0,2.0],"b":"yo"}
Bar(yo,None)
Baz(None,12)
Qux(12,List(),yo)
Bar(yo,Some(lo))
Baz(Some(42),12)
Qux(12,List(1.0, 2.0),yo)

This solution works as long as no keys clash for optional fields or for fields of collection/array types (which are optional by default), and as long as the set of required keys is different for each subtype. Some of these clashes can be resolved by parsing the values, and reporting errors where they do not fit, instead of skipping them.
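As an illustration of distinguishing subtypes by value rather than by key, here is a minimal sketch that reuses only the reader calls shown above. The types `Shape`, `Named`, and `Counted` and the object `ShapeCodecs` are hypothetical names invented for this example. Both subtypes use the same key "a", so the codec peeks at the first token of the value: a double quote means a string, anything else is left to the numeric codec to accept or reject.

```scala
import com.github.plokhotnyuk.jsoniter_scala.core._
import com.github.plokhotnyuk.jsoniter_scala.macros._

// Hypothetical ADT: both subtypes share the key "a" and differ only in the value type.
sealed trait Shape
final case class Named(a: String) extends Shape
final case class Counted(a: Int) extends Shape

object ShapeCodecs {
  val namedCodec: JsonValueCodec[Named] = JsonCodecMaker.make(CodecMakerConfig)
  val countedCodec: JsonValueCodec[Counted] = JsonCodecMaker.make(CodecMakerConfig)

  implicit val shapeCodec: JsonValueCodec[Shape] = new JsonValueCodec[Shape] {
    override def decodeValue(in: JsonReader, default: Shape): Shape = {
      in.setMark()
      if (in.isNextToken('{')) {
        val l = in.readKeyAsCharBuf()
        if (!in.isCharBufEqualsTo(l, "a")) in.unexpectedKeyError(l)
        // Peek at the first token of the value: '"' starts a JSON string.
        val isString = in.isNextToken('"')
        // Rewind to the start of the object and let the derived codec parse it fully.
        in.rollbackToMark()
        if (isString) namedCodec.decodeValue(in, namedCodec.nullValue)
        else countedCodec.decodeValue(in, countedCodec.nullValue)
      } else in.readNullOrTokenError(default, '{')
    }

    override def encodeValue(x: Shape, out: JsonWriter): Unit = x match {
      case n: Named   => namedCodec.encodeValue(n, out)
      case c: Counted => countedCodec.encodeValue(c, out)
      case null       => out.writeNull()
    }

    override val nullValue: Shape = null
  }
}
```

With `import ShapeCodecs._` in scope, `readFromString[Shape]("""{"a":"yo"}""")` should give a `Named` and `readFromString[Shape]("""{"a":12}""")` a `Counted`. A value that is neither a string nor an int is rejected by the derived `Counted` codec, which is the "reporting errors instead of skipping" part.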

If there are more than 8 keys, matching by the key's hash code is more efficient (hash collisions are then resolved by exact comparison).
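The idea of hash-based matching with collision handling can be sketched without the library at all. The following is only an illustration on plain strings; the object `KeyIndex`, its keys, and the returned ids are hypothetical, and a real codec would hash the reader's key buffer instead of a `String`. The constants are `String.hashCode` values, and "Aa" and "BB" are a well-known colliding pair.

```scala
object KeyIndex {
  // Returns a small integer id for a known key, or -1 for an unknown one.
  // Dispatch happens on the hash first; the exact comparison afterwards
  // resolves collisions and rejects unknown keys that share a hash.
  def of(key: String): Int = key.hashCode match {
    case 97   => if (key == "a") 0 else -1 // "a".hashCode == 97
    case 98   => if (key == "b") 1 else -1 // "b".hashCode == 98
    case 2112 =>                           // "Aa" and "BB" both hash to 2112
      if (key == "Aa") 2
      else if (key == "BB") 3
      else -1
    case _    => -1
  }
}
```

The returned id can then drive the same kind of bit accumulation as in the codec above, for example by clearing bit `1 << id` for each required key that was seen.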

BTW, you can see the code of derived codecs by turning on the -Xmacro-settings:print-codecs option for scalac. Or you can just open the following link to see build logs with the code of the codecs that are used in benchmarks, including ADTs: https://plokhotnyuk.github.io/jsoniter-scala/openjdk-13.txt
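For an sbt build, one way to pass that scalac option is through `scalacOptions`. This is a minimal fragment, assuming a standard `build.sbt`; other build tools have their own way of passing compiler flags.

```scala
// build.sbt (fragment): make the jsoniter-scala macros print the generated codecs during compilation
scalacOptions += "-Xmacro-settings:print-codecs"
```

The generated code then shows up in the compiler output, so it is probably best enabled only while inspecting codecs.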

So implementing a withUntaggedAdt option for automatic derivation looks possible, but I'm not sure it can be done easily, securely, and efficiently for all supported data types.

carymrobbins commented 4 years ago

I did something similar, but not quite as efficient:

https://gist.github.com/carymrobbins/d0b900257cadb458b3de9b1b532cb2b3

I had played with a few approaches and this one seems to work best. However, I'm not sure exactly how the mark stuff works and if I did any of this correctly.