rjagerman / glint

Glint: High-performance Scala parameter server
MIT License

Spark integration and examples #28

Open MLnick opened 8 years ago

MLnick commented 8 years ago

Hi @rjagerman, this project looks very interesting and I'd like to explore it a bit more. You mention Spark integration as a goal; has any work been done on that? And what about example algorithms using this parameter server?

rjagerman commented 8 years ago

Thanks for your interest in the project! Currently the parameter server runs stand-alone outside of Spark, which means you'd have to run the master and servers on a cluster as Java processes separate from Spark (the scripts in the sbin folder can help with that). When you spawn big matrices or vectors on the parameter server you get a BigMatrix or BigVector object back, respectively. These objects are serializable and can safely be used within a Spark function, so running it with Spark is straightforward:

val matrix = glintclient.matrix[Double](100000, 1000)
val rdd = ... // some Spark RDD
rdd.foreachPartition { partition =>
    ... // this will get executed on your Spark workers
    ... // they can safely use the matrix object
    ... // e.g. matrix.pull(Array(0, 1, 2, 3))
}

This project is part of my ongoing Master's thesis. I currently have an implementation of LDA that uses this parameter server and Spark, and that scales to 2TB of data and thousands of topics on a moderate computing cluster. I will open-source that implementation within the next month or so.

I will also write more extensive documentation going into more detail on how to set up a cluster, with some code examples and some tuning tips regarding ExecutionContexts, timeouts, etc. Due to other deadlines for my thesis I haven't found the time for this yet.
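
As a quick preview of those tuning tips: pulls and pushes are asynchronous and return Scala Futures, so you need an ExecutionContext in scope and an explicit timeout whenever you block on a result. A minimal sketch, reusing the matrix from the snippet above (the global context works for testing, but a dedicated thread pool is usually the better choice):

import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration._

// The ExecutionContext runs the callbacks of asynchronous pull/push requests
implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

// Pull three rows of the matrix, blocking with an explicit timeout
val rows = Await.result(matrix.pull(Array(0L, 1L, 2L)), 30.seconds)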

MLnick commented 8 years ago

Ok, thanks. I'd like to see if I can find some time to investigate some implementations on top of this.

Do you know how it compares to https://github.com/dmlc/ps-lite in terms of performance, etc.? In terms of ease of integration with frameworks like Spark, something based on Scala/Akka seems much nicer.

rjagerman commented 8 years ago

I have not yet compared the performance to ps-lite (or any other parameter server, for that matter), but it would be a very interesting thing to measure. If I had to guess I'd say that ps-lite is much faster, especially given the very early alpha state and experimental nature of Glint compared to the very mature implementation of ps-lite. My main goal is indeed to have a parameter server that is very easily integrated with Spark. So far it seems to work well in my practical case (LDA with collapsed Gibbs sampling), but some raw numbers like updates/sec, requests/sec, etc. would be very interesting. I'll keep you updated here when I have some real measurements.

Additionally, Glint currently has no fault tolerance (unlike ps-lite, which offers instantaneous failover), so if a server goes down its data is lost. This is something I definitely wish to address in the future, but it is at the moment outside the scope of the project.

MLnick commented 8 years ago

I've started working on some POCs for Spark integration, starting with linear models for simplicity. I will see if I can do some performance testing at some point. Would you be willing to share some code samples you have for Spark integration (privately if need be)?

rjagerman commented 8 years ago

That sounds great! :+1:

I have just open-sourced the LDA implementation at https://github.com/rjagerman/glintlda, so you can take a look at that. It is based on some state-of-the-art LDA research, and the many low-level optimizations, caches and buffers make the code base a bit hard to follow. I can give you some pointers:

The construction of the count-table matrix for the collapsed Gibbs sampler happens here: https://github.com/rjagerman/glintlda/blob/master/src/main/scala/glintlda/LDAModel.scala#L133

val topicWordCounts = gc.matrix[Long](config.vocabularyTerms, config.topics, 2, (x,y) => CyclicPartitioner(x, y))

The solver uses Spark to map partitions of our RDD to resampled partitions (effectively doing an LDA iteration) here: https://github.com/rjagerman/glintlda/blob/master/src/main/scala/glintlda/Solver.scala#L255

rdd = rdd.mapPartitionsWithIndex { case (id, it) =>
    val s = solver(model, id)
    val partitionSamples = it.toArray
    s.fit(partitionSamples, t)
    partitionSamples.toIterator
}

The s.fit(...) function internally uses the earlier constructed matrix on the parameter servers to perform an actual iteration of the algorithm and updates the counts on the matrix accordingly. It iterates over the matrix on the parameter servers here: https://github.com/rjagerman/glintlda/blob/master/src/main/scala/glintlda/mh/MHSolver.scala#L44, and updates counts on the parameter server here: https://github.com/rjagerman/glintlda/blob/master/src/main/scala/glintlda/mh/MHSolver.scala#L245 and here: https://github.com/rjagerman/glintlda/blob/master/src/main/scala/glintlda/mh/MHSolver.scala#L275
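
In simplified form, the pattern inside fit is: pull the relevant counts, resample locally, and push the deltas back. A minimal sketch (not the actual MHSolver code; the indices are illustrative):

import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration._

implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

// Pull the current counts for the rows (words) this partition touches
val wordRows = Array(0L, 1L, 2L)
val counts = Await.result(topicWordCounts.pull(wordRows), 30.seconds)

// ... resample topic assignments locally using the pulled counts ...

// Push count deltas back: pushes add to the existing values, so
// concurrent workers can update the same cells without coordination
topicWordCounts.push(Array(0L), Array(5), Array(1L))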

We make extensive use of buffers for performance reasons, which can obfuscate the code quite a bit. The locks that you see in the code act as a back-pressure mechanism that limits the number of open requests to the parameter servers. This is necessary because the code is fast enough that it could easily flood the parameter servers by asynchronously sending requests.
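
The effect of those locks can be illustrated with a plain semaphore that caps the number of in-flight requests; a simplified sketch with a hypothetical limitedPush helper (not the actual buffering code):

import java.util.concurrent.Semaphore
import scala.concurrent.ExecutionContext

implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

// Allow at most 16 outstanding requests to the parameter servers
val inFlight = new Semaphore(16)

def limitedPush(rows: Array[Long], cols: Array[Int], values: Array[Long]): Unit = {
  inFlight.acquire() // blocks when too many requests are already open
  topicWordCounts.push(rows, cols, values).onComplete { _ =>
    inFlight.release() // free a slot once the server has responded
  }
}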

I will be finished with my thesis in about 2 weeks, after which I'll have some more time to create some minimal examples that are much easier to read and understand. In the meantime I hope this helps! :-)

MLnick commented 8 years ago

Great, I will take a deeper look at that code. The general approach is along the lines of mine (i.e. running iterations within mapPartitions).

rjagerman commented 8 years ago

Another hint that may be helpful: in my Glint configuration I increased the Akka frame size and heartbeat timeouts. This allows me to send much larger pull and push requests. My configuration file looks like this:

glint.master.host   = "127.0.0.1"
glint.master.port   = 13370
glint {
  server.akka.loglevel = "INFO"
  server.akka.stdout-loglevel = "INFO"
  client.akka.loglevel = "ERROR"
  client.akka.stdout-loglevel = "ERROR"
  master.akka.loglevel = "INFO"
  master.akka.stdout-loglevel = "INFO"
  master.akka.remote.log-remote-lifecycle-events = on
  server.akka.remote.log-remote-lifecycle-events = off
  client.akka.remote.log-remote-lifecycle-events = on
  client.timeout = 30 s

  master.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s
  server.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s
  client.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s
  master.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s
  server.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s
  client.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s

  server.akka.remote.netty.tcp.maximum-frame-size = 10240000b
  client.akka.remote.netty.tcp.maximum-frame-size = 10240000b
  server.akka.remote.netty.tcp.send-buffer-size = 20480000b
  client.akka.remote.netty.tcp.send-buffer-size = 20480000b
  server.akka.remote.netty.tcp.receive-buffer-size = 20480000b
  client.akka.remote.netty.tcp.receive-buffer-size = 20480000b
}

rjagerman commented 8 years ago

I've added slightly more comprehensive documentation at http://rjagerman.github.io/glint/. This might be of use to you, as it includes a short section on Spark integration and serialization (see the getting-started guide).

I'm also currently working on getting some benchmarks for common tasks (e.g. logistic regression, SVMs, regression, all-reduce, etc.) so we can compare Glint against Spark and other frameworks. In time this will produce some more example code for these common tasks.

MLnick commented 8 years ago

Thanks Rolf! I have been swamped with Spark 2.0 work, but I'd like to get back to some PoC work with Glint after that is done.

codlife commented 8 years ago

Hi @rjagerman, after looking at your code I think it is clear how to use your API inside Spark's map or mapPartitions operations. But I find it a little hard to use your API when an algorithm relies on Spark aggregations, for example when I want to implement gradient descent the way MLlib does below. Thanks!

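// Excerpt from Spark MLlib's GradientDescent.runMiniBatchSGD: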
while (!converged && i <= numIterations) {
  val bcWeights = data.context.broadcast(weights)
  // Sample a subset (fraction miniBatchFraction) of the total data
  // compute and sum up the subgradients on this subset (this is one map-reduce)
  val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
    .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
      seqOp = (c, v) => {
        // c: (grad, loss, count), v: (label, features)
        val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
        (c._1, c._2 + l, c._3 + 1)
      },
      combOp = (c1, c2) => {
        // c: (grad, loss, count)
        (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
      })

  if (miniBatchSize > 0) {
    /**
     * lossSum is computed using the weights from the previous iteration
     * and regVal is the regularization value computed in the previous iteration as well.
     */
    stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
    val update = updater.compute(
      weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble),
      stepSize, i, regParam)
    weights = update._1
    regVal = update._2

    previousWeights = currentWeights
    currentWeights = Some(weights)
    if (previousWeights != None && currentWeights != None) {
      converged = isConverged(previousWeights.get,
        currentWeights.get, convergenceTol)
    }
  } else {
    logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
  }
  i += 1
}
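
For comparison, what I have in mind with Glint would replace the broadcast and treeAggregate with pulls and pushes against a vector on the parameter server. A rough, hypothetical sketch (assuming weights is a Glint BigVector[Double] of length n and data is an RDD of (label, features) pairs with Array[Double] features; no convergence check shown):

import scala.concurrent.{Await, ExecutionContext}
import scala.concurrent.duration._

var i = 1
while (i <= numIterations) {
  data.sample(false, miniBatchFraction, 42 + i).foreachPartition { it =>
    implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

    // Pull the current weights from the parameter server
    val keys = (0L until n).toArray
    val w = Await.result(weights.pull(keys), 30.seconds)

    // Accumulate this partition's gradient locally (squared loss shown)
    val grad = new Array[Double](n)
    it.foreach { case (label, features) =>
      val err = (0 until n).map(j => w(j) * features(j)).sum - label
      var j = 0
      while (j < n) { grad(j) += err * features(j); j += 1 }
    }

    // Push the scaled negative gradient as an additive update
    weights.push(keys, grad.map(g => -stepSize * g))
  }
  i += 1
}

In practice you would pull only the feature indices that actually occur in the partition rather than the full vector, which is where a parameter server should beat broadcast + treeAggregate for sparse, high-dimensional models.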