disneystreaming / smithy4s

https://disneystreaming.github.io/smithy4s/

`DynamicSchemaIndex.loadModel` is slow #1258

Closed (kubukoz closed this 3 months ago)

kubukoz commented 9 months ago

On my machine (M1 Max MBP), loading a relatively large model (15k+ shapes) takes less than a second on the Smithy (AWS) side but over 11 seconds on the DynamicSchemaIndex side.

You can try it yourself, with timings for each step, using the following scala-cli script:

//> using lib "com.disneystreaming.smithy4s::smithy4s-dynamic:0.18.2"
//> using lib "io.get-coursier:coursier_2.13:2.1.7"
//> using scala "3.3.1"
//> using toolkit "latest"
//> using option "-Wunused:imports"
import coursier._
import coursier.parse.DependencyParser
import smithy4s.Document
import smithy4s.dynamic.DynamicSchemaIndex
import smithy4s.dynamic.NodeToDocument
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ModelSerializer
import software.amazon.smithy.model.transform.ModelTransformer
import scala.concurrent.ExecutionContext

def measure[A](tag: String)(f: => A): A = {
  val start = System.nanoTime()
  val result = f
  println(s"$tag: ${(System.nanoTime() - start) / 1000000}ms")
  result
}

@main def main = {

  val jars = measure("fetch jars") {
    Fetch()
      .withDependencies(
        List(
          "aws-quicksight-spec",
          "aws-sagemaker-spec"
        ).takeRight(3)
          .map { sdk =>
            DependencyParser
              .dependency(
                s"com.disneystreaming.smithy:$sdk:2023.09.22",
                "2.13"
              )
              .toOption
              .get
          }
      )
      .run()(ExecutionContext.global)
      // poor man's way of filtering out jars without a smithy manifest. don't judge pls
      .filterNot(_.toString().contains("smithy-model"))
      .filterNot(_.toString().contains("smithy-jmespath"))
      .filterNot(_.toString().contains("smithy-utils"))
  }

  val model = measure("load model") {
    val assembler = Model
      .assembler()

    jars.foreach(jar => assembler.addImport(jar.toURI().toURL()))

    assembler
      .assemble()
      .unwrap()
  }

  println(model.shapes().toList().size() + " shapes")

  val codec = measure("derive codec") {
    Document.decoderFromSchema[smithy4s.dynamic.model.Model]
  }

  // the equivalent of loadModel's first part
  val converted = measure("convert model") {
    val flattenedModel =
      ModelTransformer.create().flattenAndRemoveMixins(model);
    val node = ModelSerializer.builder().build.serialize(flattenedModel)
    val document = NodeToDocument(node)

    codec.decode(document).toTry.get
  }

  measure("load dsi") { DynamicSchemaIndex.load(converted) }
}

Here are the timings:

fetch jars: 359ms
load model: 624ms
15545 shapes
derive codec: 75ms
convert model: 281ms
load dsi: 11174ms

Ideas for improvement

NOTE: some models are not like the others

The characteristics of a model, e.g. how deep vs how wide the shapes are, may have an impact on the timings here. The profiles described below come from distinct models that may or may not have similar depths.

Optimizing recursiveness checks

First, here's a profiler run on a confidential model that apparently has deep shape closures.

I ran a profiler during the load dsi step, and it turns out that the biggest culprit is the recursion checks, which take over 70% of the entire load call:

[Screenshot: profiler output for the load dsi step]

Perhaps there's a more efficient way of performing such checks. I haven't seen the code, but transitiveClosure may be trying to do too much (i.e. collecting all shapes into a large set, rather than recursing until a known ShapeId is found).


Now here's the flame graph for the AWS model made from the specs in the snippet above:

[Screenshot: flame graph for the AWS model]

Zoomed in:

[Screenshot: zoomed-in flame graph]

Again, isRecursive takes up 86% of the whole run. If we can optimize that part, we can save a lot of time when loading dynamic models in general.

Extra info

I checked isRecursive for outliers, and it appears that only about a dozen shapes (out of thousands) take over 500ms to check:

Warning: took 603 ms to compile com.amazonaws.quicksight#AnalysisDefinition
Warning: took 563 ms to compile com.amazonaws.quicksight#CreateAnalysisRequest
Warning: took 588 ms to compile com.amazonaws.quicksight#CreateDashboardRequest
Warning: took 572 ms to compile com.amazonaws.quicksight#CreateTemplateRequest
Warning: took 560 ms to compile com.amazonaws.quicksight#DashboardVersionDefinition
Warning: took 520 ms to compile com.amazonaws.quicksight#DescribeAnalysisDefinitionResponse
Warning: took 518 ms to compile com.amazonaws.quicksight#DescribeDashboardDefinitionResponse
Warning: took 520 ms to compile com.amazonaws.quicksight#DescribeTemplateDefinitionResponse
Warning: took 504 ms to compile com.amazonaws.quicksight#SheetDefinition

This suggests that they're extra deep or wide, but I haven't checked that yet.
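One way to check the "extra deep or wide" hypothesis is to count how many shapes each slow shape can reach. This is a minimal sketch, assuming the closure map is available as a plain `Map[A, Set[A]]`. `ClosureStats`, `closureSize` and `largestClosures` are hypothetical names, not smithy4s APIs. Running it over every shape is itself quadratic, so it's only meant for spot-checking outliers like the ones above.

```scala
import scala.collection.mutable

object ClosureStats {
  // Number of distinct shapes reachable from `start` via one or more edges
  // (includes `start` itself only if it sits on a cycle).
  // Iterative BFS, so very deep models don't overflow the stack.
  def closureSize[A](map: Map[A, Set[A]], start: A): Int = {
    val visited = mutable.Set.empty[A]
    val queue = mutable.Queue.empty[A]
    queue ++= map.getOrElse(start, Set.empty[A])
    while (queue.nonEmpty) {
      val next = queue.dequeue()
      // `add` returns false if `next` was already visited
      if (visited.add(next)) queue ++= map.getOrElse(next, Set.empty[A])
    }
    visited.size
  }

  // The `n` shapes with the largest closures, biggest first.
  def largestClosures[A](map: Map[A, Set[A]], n: Int): List[(A, Int)] =
    map.keys.toList
      .map(k => k -> closureSize(map, k))
      .sortBy(-_._2)
      .take(n)
}
```

Closure size mixes depth and width into one number; if it turns out to be large for exactly the slow shapes listed above, that would support the hypothesis.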

I also checked what happens if we rewrite isRecursive to be lazier and return as soon as the current shape is found (note: I didn't check this for correctness):

private def isRecursive(
    id: ShapeId
): Boolean = measureIfSlow(id.toString) {
  def transitiveClosureContainsSelf(
      _id: ShapeId,
      visited: Set[ShapeId]
  ): Boolean = {
    val neighbours = closureMap.getOrElse(_id, Set.empty)

    if (neighbours.contains(id)) true
    else {
      val newVisited = visited + _id

      neighbours.iterator
        .filterNot(newVisited)
        .exists(
          transitiveClosureContainsSelf(_, newVisited)
        )
    }
  }

  transitiveClosureContainsSelf(id, Set.empty)
}

It didn't help much:

Warning: took 622 ms to compile com.amazonaws.quicksight#AnalysisDefinition
Warning: took 575 ms to compile com.amazonaws.quicksight#CreateAnalysisRequest
Warning: took 583 ms to compile com.amazonaws.quicksight#CreateDashboardRequest
Warning: took 576 ms to compile com.amazonaws.quicksight#CreateTemplateRequest
Warning: took 576 ms to compile com.amazonaws.quicksight#DashboardVersionDefinition
Warning: took 575 ms to compile com.amazonaws.quicksight#DescribeAnalysisDefinitionResponse
Warning: took 579 ms to compile com.amazonaws.quicksight#DescribeDashboardDefinitionResponse
Warning: took 572 ms to compile com.amazonaws.quicksight#DescribeTemplateDefinitionResponse
Warning: took 562 ms to compile com.amazonaws.quicksight#SheetDefinition
Warning: took 578 ms to compile com.amazonaws.quicksight#TemplateVersionDefinition
Warning: took 574 ms to compile com.amazonaws.quicksight#UpdateAnalysisRequest
Warning: took 577 ms to compile com.amazonaws.quicksight#UpdateDashboardRequest
Warning: took 579 ms to compile com.amazonaws.quicksight#UpdateTemplateRequest
Warning: took 551 ms to compile com.amazonaws.quicksight#Visual
load dsi: 12052ms

Neither did caching, nor a combination of the two. I'm hoping it's just a mistake on my side and the closure check can still be optimized.
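One possible reason the lazier version barely helped: its `visited` set only tracks the current path, so sibling branches don't share it, and a shape reachable through many different paths is re-explored once per path. On heavily shared models that can grow exponentially. Below is a sketch of the same early-exit check with a single visited set per query, assuming a plain `Map[A, Set[A]]` closure map (`RecursionCheck` is a hypothetical name).

```scala
import scala.collection.mutable

object RecursionCheck {
  // True if `id` can reach itself by following edges in `map`.
  // `visited` is shared across the whole search, so each shape is expanded
  // at most once per query: O(V + E) per call instead of once per path.
  def isRecursive[A](map: Map[A, Set[A]], id: A): Boolean = {
    val visited = mutable.Set.empty[A]
    val stack = mutable.Stack.empty[A]
    stack.pushAll(map.getOrElse(id, Set.empty[A]))
    var found = false
    while (!found && stack.nonEmpty) {
      val next = stack.pop()
      if (next == id) found = true // back at the starting shape: it's recursive
      else if (visited.add(next)) stack.pushAll(map.getOrElse(next, Set.empty[A]))
    }
    found
  }
}
```

This is still one traversal per shape, so over 15k shapes it remains quadratic in the worst case; computing the whole set of recursive shapes in one pass avoids that.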

Baccata commented 9 months ago

I think the key is gonna be to precompute a set of recursive shapes from the ClosureMap.

If you can optimise the Map[ShapeId, Set[ShapeId]] => Set[ShapeId] function, you've essentially won.

You can use the fact that if you traverse several shapes before finding a recursion chain, ALL the shapes of the chain are recursive.

You can also use the fact that once a ShapeId is proven non-recursive, there's no need to traverse it ever again.

So I think an implementation using two mutable sets, one for proven-recursive shapes and one for proven-non-recursive ones, would be decent.

Baccata commented 9 months ago

Something like:

import scala.collection.mutable.{Set => MSet}
import scala.collection.immutable.ListSet

def recursiveShapes[A](map: Map[A, Set[A]]): Set[A] = {
  val provenRecursive: MSet[A] = MSet.empty
  val provenNotRecursive: MSet[A] = MSet.empty

  def crawl(key: A, seen: ListSet[A]): Unit = {
    if (provenNotRecursive(key)) ()
    else if (provenRecursive(key)) {
      // here we may have found a new "recursive path" between
      // elements that we know are recursive
      // ps : I'm not sure about this branch, it should probably be dropped
      provenRecursive ++= seen.dropWhile(s => !provenRecursive(s))
    } else if (seen(key)) {
      // dropping elements that don't belong to the "recursive path"
      provenRecursive ++= seen.dropWhile(_ != key)
    } else
      map.get(key) match {
        case None => ()
        case Some(values) =>
          values.foreach(crawl(_, seen + key))
          if (!provenRecursive(key)) provenNotRecursive += key
      }
  }

  map.keySet.foreach(crawl(_, ListSet.empty))
  provenRecursive.toSet
}

@main()
def test = {
  val map = Map[Int, Set[Int]](
    1 -> Set(2),
    2 -> Set(3, 4, 5),
    3 -> Set(4),
    5 -> Set(6, 7),
    6 -> Set(2),
    7 -> Set(2)
  )
  println(recursiveShapes(map))
}

PS: this needs to be tested a lot more thoroughly, obviously.
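One way to get that extra testing is to compare the implementation against a slow but obviously correct oracle on many small random graphs. This is a sketch assuming `Map[Int, Set[Int]]` inputs; `RecursiveShapesCheck`, `oracle`, `randomGraph` and `firstCounterexample` are hypothetical helpers, not part of smithy4s.

```scala
import scala.util.Random

object RecursiveShapesCheck {
  // Slow reference: a key is recursive iff it can reach itself.
  def oracle[A](map: Map[A, Set[A]]): Set[A] = {
    // Level-by-level BFS from `from`, stopping once `target` shows up.
    def reaches(from: A, target: A): Boolean = {
      var visited = Set.empty[A]
      var frontier = map.getOrElse(from, Set.empty[A])
      while (frontier.nonEmpty && !frontier.contains(target)) {
        visited ++= frontier
        frontier = frontier.flatMap(a => map.getOrElse(a, Set.empty[A])) -- visited
      }
      frontier.contains(target)
    }
    map.keySet.filter(k => reaches(k, k))
  }

  // Random directed graph over nodes 0 until size; each edge (self-loops
  // included) is present with probability p.
  def randomGraph(rng: Random, size: Int, p: Double): Map[Int, Set[Int]] =
    (0 until size).map { from =>
      from -> (0 until size).filter(_ => rng.nextDouble() < p).toSet
    }.toMap

  // First random graph (1 to 8 nodes) on which `impl` disagrees with the oracle.
  def firstCounterexample(
      impl: Map[Int, Set[Int]] => Set[Int],
      trials: Int,
      seed: Long = 42L
  ): Option[Map[Int, Set[Int]]] = {
    val rng = new Random(seed)
    Iterator
      .fill(trials)(randomGraph(rng, rng.nextInt(8) + 1, 0.25))
      .find(g => impl(g) != oracle(g))
  }
}
```

With the `recursiveShapes` above in scope, `RecursiveShapesCheck.firstCounterexample(g => recursiveShapes(g), 1000)` returns `None` if no disagreement was found, or `Some(graph)` with a small reproduction to debug.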

Baccata commented 9 months ago

If you want to go ultra optimal, Tarjan's strongly connected components algorithm is probably the best choice: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
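A sketch of the Tarjan approach, assuming the same `Map[A, Set[A]]` input as `recursiveShapes` above. A shape is recursive iff its strongly connected component has more than one member, or it has an edge to itself. This is a textbook recursive version (the `Tarjan` object is a hypothetical name, not the smithy4s implementation); very deep models may need an iterative rewrite to avoid stack overflows.

```scala
import scala.collection.mutable

object Tarjan {
  // Shapes that can reach themselves: members of non-trivial SCCs plus self-loops.
  def recursiveShapes[A](map: Map[A, Set[A]]): Set[A] = {
    val index = mutable.Map.empty[A, Int]   // discovery order of each shape
    val lowLink = mutable.Map.empty[A, Int] // smallest index reachable via the DFS stack
    val onStack = mutable.Set.empty[A]
    val stack = mutable.Stack.empty[A]
    val result = Set.newBuilder[A]
    var counter = 0

    def strongConnect(v: A): Unit = {
      index(v) = counter
      lowLink(v) = counter
      counter += 1
      stack.push(v)
      onStack += v

      map.getOrElse(v, Set.empty[A]).foreach { w =>
        if (!index.contains(w)) {
          strongConnect(w)
          lowLink(v) = math.min(lowLink(v), lowLink(w))
        } else if (onStack(w)) {
          lowLink(v) = math.min(lowLink(v), index(w))
        }
      }

      // v is the root of an SCC: pop the whole component off the stack.
      if (lowLink(v) == index(v)) {
        val component = mutable.ListBuffer.empty[A]
        var member = stack.pop()
        onStack -= member
        component += member
        while (member != v) {
          member = stack.pop()
          onStack -= member
          component += member
        }
        val selfLoop = map.getOrElse(v, Set.empty[A]).contains(v)
        if (component.size > 1 || selfLoop) result ++= component
      }
    }

    map.keysIterator.foreach(v => if (!index.contains(v)) strongConnect(v))
    result.result()
  }
}
```

Each shape and edge is visited once, so the pass is linear in the size of the closure map; with the set computed up front, isRecursive reduces to a set lookup.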

kubukoz commented 9 months ago

For the AWS example I tried earlier, the implementation you provided was enough to cut it down by 10x ;)

We could try Tarjan too. I'll see when I can get to this, if nobody picks it up in the meantime.

kubukoz commented 9 months ago

thanks for the research!

daddykotex commented 9 months ago


That's impressive, and so is the pseudo-algorithm you just pulled out of thin air, wow.