darrenjw / scala-glm

Scala library for fitting linear and generalised linear statistical models
Apache License 2.0

Unexpectedly high memory usage for large models. #1

Open espears4sq opened 6 years ago

espears4sq commented 6 years ago

When testing how well the LogisticGlm model scales with a large toy data set, I am finding on my local machine (16 GB RAM) that I hit out-of-memory errors even for fairly small problem sizes.

Here is some example code to make a toy logistic regression:

import breeze.linalg.{DenseMatrix, DenseVector}
import scalaglm.{Glm, LogisticGlm}

object glm extends App {

  // Helper function to map synthetically generated data into
  // training labels of a logistic regression.
  def logistic_fn(x: Double): Double = {
    1.0 / (1.0 + math.exp(-x))
  }

  def fit_logistic(): Glm = {
    // Parameters for creating synthetic data
    val r = new scala.util.Random(0)
    val normal = breeze.stats.distributions.Gaussian(0, 1)

    // Define problem size num_observations x num_features
    val num_observations = 1000000
    val num_features = 50

    val beta = DenseVector.rand(num_features) :* 5.0
    val names = for (i <- 1 to num_features) yield "var_%d".format(i)
    println("True coefficients:")
    println(beta(0 to 10))

    // Create synthetic logistic regression data set.
    val x = DenseMatrix.rand(num_observations, num_features, normal)
    x(::, 0) := 1.0
    val true_logits = x * beta
    val y = true_logits map logistic_fn map {p_i => (if (r.nextDouble < p_i) 1.0 else 0.0)}

    val t1 = System.nanoTime
    val g = Glm(y, x, names, LogisticGlm, addIntercept=false, its=1000)
    println("Elapsed %4.2f for training model".format((System.nanoTime - t1) / 1e9d))
    return g
  }
}

With this problem size (1 million observations × 50 features), I immediately get an OOM error:

scala> val g = glm.fit_logistic()
True coefficients:
DenseVector(2.78135510778451, 3.6818164882958326, 3.4840289537745948, 4.912012391491977, 2.907467492064324, 0.7532367248769811, 4.496847165217405, 0.20064910613956877, 4.855909891445109, 0.6049146229107971, 4.8162668734131895)
Aug 02, 2018 11:03:48 AM com.github.fommil.jni.JniLoader liberalLoad
INFO: successfully loaded /var/folders/0p/vx1f5tn93z1dc8pzk21g5nx40000gn/T/jniloader2218725777137246063netlib-native_system-osx-x86_64.jnilib
java.lang.OutOfMemoryError: Java heap space
  at scala.reflect.ManifestFactory$DoubleManifest.newArray(Manifest.scala:153)
  at scala.reflect.ManifestFactory$DoubleManifest.newArray(Manifest.scala:151)
  at breeze.linalg.DenseMatrix$.zeros(DenseMatrix.scala:345)
  at breeze.linalg.DenseMatrix$$anon$33.$anonfun$apply$2(DenseMatrix.scala:823)
  at breeze.linalg.DenseMatrix$$anon$33.$anonfun$apply$2$adapted(DenseMatrix.scala:820)
  at breeze.linalg.DenseMatrix$$anon$33$$Lambda$5324/324878705.apply(Unknown Source)
  at scala.collection.immutable.Range.foreach(Range.scala:156)
  at breeze.linalg.DenseMatrix$$anon$33.apply(DenseMatrix.scala:820)
  at breeze.linalg.DenseMatrix$$anon$33.apply(DenseMatrix.scala:817)
  at breeze.linalg.BroadcastedColumns$$anon$4.apply(BroadcastedColumns.scala:91)
  at breeze.linalg.BroadcastedColumns$$anon$4.apply(BroadcastedColumns.scala:89)
  at breeze.linalg.ImmutableNumericOps.$times(NumericOps.scala:149)
  at breeze.linalg.ImmutableNumericOps.$times$(NumericOps.scala:148)
  at breeze.linalg.BroadcastedColumns.$times(BroadcastedColumns.scala:30)
  at scalaglm.Irls$.IRLS(Glm.scala:243)
  at scalaglm.Glm.<init>(Glm.scala:87)
  at glm$.fit_logistic(glm.scala:30)
  ... 15 elided

This is a fairly small problem instance: if I generate the equivalent data set with numpy and serialize it to a binary file on disk, it is less than 5 GB. There is no trouble loading that data and fitting the model (even with the standard-error calculations) in the statsmodels or scikit-learn libraries for Python.

What are the root causes of such unexpectedly high memory usage in scala-glm?

A secondary question is how to monitor convergence on data of this size. I can increase the number of iterations, but there is no per-iteration feedback during model fitting to indicate whether or not the fit is converging.

darrenjw commented 6 years ago

Thanks for your feedback. For really big problems you probably want to use Spark, but there's no reason why this code should not work for models that run fine in Python/statsmodels. Can I check that you are allocating plenty of heap memory? The JVM default is piffling. For example, if you are running via sbt, are you starting it with extra heap, like "sbt -mem 12000" to start with 12 GB of heap?
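
(For anyone else hitting this: besides the sbt -mem flag, a build.sbt sketch along the lines below forks the run so the forked JVM gets its own, larger heap. The sizes shown are placeholders, not recommendations.)

// build.sbt (sketch): fork `run` so it gets a JVM with a larger heap.
// The heap sizes here are placeholders; tune them to the machine.
fork := true
javaOptions ++= Seq("-Xms2G", "-Xmx12G")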

darrenjw commented 6 years ago

Regarding your second question, there is no code for this currently. There should be a "verbose" or "debug" option when calling the function that prints (or logs) diagnostic information while running. I'll file a separate issue for that.
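
(In the meantime, one workaround for watching convergence is to run a hand-rolled IRLS loop directly in Breeze and print the deviance at every iteration. The sketch below is not scala-glm's own code; the name irlsVerbose, the defaults and the stopping rule are all invented for this example, and it deliberately trades memory, an extra copy of X per iteration, for clarity.)

import breeze.linalg.{DenseMatrix, DenseVector}

// Sketch only: hand-rolled IRLS for logistic regression with per-iteration
// deviance printing. NOT scala-glm's internal implementation.
def irlsVerbose(y: DenseVector[Double], x: DenseMatrix[Double],
                its: Int = 50, tol: Double = 1e-8): DenseVector[Double] = {
  var beta = DenseVector.zeros[Double](x.cols)
  var lastDev = Double.MaxValue
  var iter = 0
  var converged = false
  while (iter < its && !converged) {
    val eta = x * beta                                  // linear predictor
    val mu = eta.map(e => 1.0 / (1.0 + math.exp(-e)))   // fitted probabilities
    val w = mu.map(m => m * (1.0 - m))                  // IRLS weights
    val z = eta + ((y - mu) :/ w)                       // working response
    val wx = x.copy                                     // note: allocates another full copy of X
    for (j <- 0 until wx.cols) wx(::, j) :*= w          // row-scale: diag(w) * X
    beta = (x.t * wx) \ (x.t * (w :* z))                // weighted least-squares step
    val dev = -2.0 * (0 until y.length).map { i =>      // binomial deviance at current mu
      y(i) * math.log(mu(i)) + (1.0 - y(i)) * math.log(1.0 - mu(i))
    }.sum
    println(f"iteration $iter%3d  deviance = $dev%14.4f")
    converged = math.abs(lastDev - dev) < tol
    lastDev = dev
    iter += 1
  }
  beta
}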

espears4sq commented 6 years ago

Thanks for your reply!

Even if I set sbt -mem 14500, I am still seeing OOM errors. I am wondering if it has to do with the matrix algorithms you are using to calculate the coefficient covariance matrix for the standard errors; that is the only aspect I can think of that would cause almost a doubling in memory usage over the basic size of the large design matrix.

darrenjw commented 6 years ago

Does it run OK if (say) you halve the number of observations and covariates? If so, I suspect the problem is just that I haven't tried to do in-place operations or similar tricks from numerical linear algebra to avoid copying and allocation. I could easily imagine that it is possible to save a factor of 2 or 4 on memory by being a bit more careful with the numerics.
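
(For scale, a rough back-of-the-envelope for the problem size in the original report, under the assumption of dense 8-byte doubles and ignoring JVM object and BLAS workspace overhead:)

// Sketch: approximate memory footprint of the example above.
val n = 1000000                  // observations
val p = 50                       // features
val xBytes = n.toLong * p * 8L   // design matrix: 400,000,000 bytes, about 0.4 GB
val vecBytes = n.toLong * 8L     // one n-length working vector: about 8 MB
// Every full copy of X made inside the fit (e.g. a weighted copy per IRLS
// iteration) adds another ~0.4 GB, so a handful of temporaries quickly
// multiplies the baseline, even before any covariance calculation.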

espears4sq commented 6 years ago

It runs as expected for smaller numbers of observations and features. I believe the time complexity of most optimization algorithms for this problem is quadratic in the number of features, though, so going from e.g. 100 to 500 feature columns causes a noticeable slowdown. I have no idea how the memory consumption grows as a function of those parameters for a non-optimized covariance calculation vs. an optimized one.
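
(For reference, the stack trace above shows the fit going through scalaglm.Irls$.IRLS, and a generic IRLS iteration is typically dominated by forming X^T W X, roughly n * p^2 operations, plus solving the resulting p x p system, roughly p^3 operations, which is consistent with the quadratic-in-features slowdown described here.)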

At any rate, this is still quite a nice library. As far as I can tell, only Spark and scala-glm provide standard errors for logistic regression estimates in the Scala ecosystem. Since this is such a fundamental need for any regression fit, I expect scala-glm could attract a lot of attention.