scalanlp / breeze

Breeze is/was a numerical processing library for Scala.
https://www.scalanlp.org
Apache License 2.0
3.45k stars 693 forks source link

Breeze "sortrows" function - first implementation below - help wanted? #770

Open Quafadas opened 4 years ago

Quafadas commented 4 years ago

One thing I have observed with Breeze versus other libraries is that some "higher order" functions are missing. For example, I wanted "sortrows", similar to MATLAB's.

I have a naive implementation below. I tried to add it to breeze itself by constructing a PR, but drowned in Generics, implicits and the difficulties of working inside a large project.

I'd love to see it added to the core library, simply because, for my use case, it makes seeing inside the data much easier. Unfortunately, the integration piece is beyond my skill, and I'm not clear on what the quality requirements would be.

I'm posting it here in case it helps someone with a similar requirement, in case someone is willing to help me contribute it, or in case it is useful as an algorithmic outline.

 package helperFuns

  import breeze.linalg.{*, DenseMatrix, unique}

  /** Lexicographic row sorting for dense matrices, in the spirit of MATLAB's `sortrows`. */
  object sortrows {

    /** Returns a copy of `dm` whose rows are sorted in ascending order by the given key columns.
      *
      * Rows are compared on `cols(0)` first; ties are broken by `cols(1)`, then `cols(2)`, and so
      * on. Rows that tie on every key column keep their original relative order (the sort is
      * stable). Columns are never reordered: every output row is an intact row of the input.
      *
      * Value ordering: `-0.0` and `0.0` compare as equal; `NaN` compares as equal to itself and
      * greater than every other value, so rows with a `NaN` key sort last.
      *
      * @param dm   the matrix to sort; it is not modified
      * @param cols zero-based indices of the key columns, highest priority first. An empty
      *             sequence means "no keys", in which case an unsorted copy of `dm` is returned.
      *             A column listed more than once is harmless (later occurrences never break a tie).
      * @return a new matrix with the same shape as `dm` and its rows in sorted order
      * @throws IllegalArgumentException if any entry of `cols` is not a valid column index of `dm`
      */
    def apply(dm: DenseMatrix[Double], cols: IndexedSeq[Int]): DenseMatrix[Double] = {
      require(
        cols.forall(c => c >= 0 && c < dm.cols),
        s"sort columns $cols must all lie in [0, ${dm.cols})"
      )

      // Copy the keys into a primitive array: the comparator below runs O(n log n) times.
      val keys = cols.toArray

      val byKeyColumns: Ordering[Int] = new Ordering[Int] {
        def compare(a: Int, b: Int): Int = compareRows(dm, keys, a, b)
      }

      // Sort row *indices* rather than rows. `sorted` is documented as stable, which is what
      // keeps fully tied rows in their input order.
      val order = (0 until dm.rows).sorted(byKeyColumns).toArray

      // Gather the rows in their final order in one pass; column positions are untouched.
      DenseMatrix.tabulate(dm.rows, dm.cols)((i, j) => dm(order(i), j))
    }

    /** Compares rows `a` and `b` of `dm` on each key column in turn, stopping at the first difference. */
    private def compareRows(dm: DenseMatrix[Double], keys: Array[Int], a: Int, b: Int): Int = {
      var k = 0
      var result = 0
      while (result == 0 && k < keys.length) {
        result = compareValues(dm(a, keys(k)), dm(b, keys(k)))
        k += 1
      }
      result
    }

    /** Total order on doubles that treats `-0.0` and `0.0` as equal and places `NaN` last. */
    private def compareValues(x: Double, y: Double): Int =
      // `x == y` catches the signed-zero case; everything else (including NaN, for which `==` is
      // always false) is delegated to the JDK's total ordering so the comparator stays consistent.
      if (x == y) 0 else java.lang.Double.compare(x, y)
  }

Tests

package models

import java.io.File

import breeze.linalg.{DenseMatrix, DenseVector, convert, csvwrite}
import org.scalatest.{FlatSpec, Matchers}
import helperFuns.sortrows

/** Behavioural tests for `sortrows`: single-key sorting, multi-key tie-breaking, and a large input. */
class SortRowsSpec extends FlatSpec with Matchers {

  // A 7x1 column with repeated values; it doubles as the first column of `table` below.
  val groups = DenseMatrix(1.0, 1.0, 3.0, 2.0, 3.0, 1.0, 1.0)

  // Shared 7x4 fixture for the multi-key tests: `groups` followed by three further columns.
  private val table: DenseMatrix[Double] = {
    val dv = DenseMatrix(0.0, 5.0, 6.0, 3.0, 7.0, 3.0, 1.0)
    val dv1 = DenseMatrix(2.0, 6.0, 7.0, 4.0, 8.0, 4.0, 1.0)
    val dv2 = DenseMatrix(10.0, 1.0, 4.0, 2.0, 3.0, 0.0, 10.0)
    DenseMatrix.horzcat(groups, dv, dv1, dv2)
  }

  "sortrows" should "sort a column" in {
    val expected = DenseMatrix(1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 3.0)

    sortrows(groups, Vector(0)) shouldEqual expected
  }

  it should "correctly sort subgroups in column preference order" in {
    // Primary key: column 0; ties broken by column 3, then column 2.
    val expected = DenseMatrix(
      (1.0, 3.0, 4.0, 0.0),
      (1.0, 5.0, 6.0, 1.0),
      (1.0, 1.0, 1.0, 10.0),
      (1.0, 0.0, 2.0, 10.0),
      (2.0, 3.0, 4.0, 2.0),
      (3.0, 7.0, 8.0, 3.0),
      (3.0, 6.0, 7.0, 4.0)
    )

    sortrows(table, Vector(0, 3, 2)) shouldEqual expected
  }

  it should "correctly sort subgroups in column preference order again" in {
    // Primary key: column 2; ties broken by column 0.
    val expected = DenseMatrix(
      (1.0, 1.0, 1.0, 10.0),
      (1.0, 0.0, 2.0, 10.0),
      (1.0, 3.0, 4.0, 0.0),
      (2.0, 3.0, 4.0, 2.0),
      (1.0, 5.0, 6.0, 1.0),
      (3.0, 6.0, 7.0, 4.0),
      (3.0, 7.0, 8.0, 3.0)
    )

    sortrows(table, Vector(2, 0)) shouldEqual expected
  }

  it should "deal with at least a non trivial number of rows" in {
    val rowCount = 1000000
    // Fixed seed so that a failure can be reproduced with exactly the same data.
    val rng = new scala.util.Random(42)
    val dm = DenseMatrix.tabulate(rowCount, 5)((i, j) => (i % 5 + j + rng.nextInt(5)).toDouble)

    val sorted = sortrows(dm, Vector(4, 1, 3, 0, 2))

    sorted.rows shouldEqual rowCount
    sorted.cols shouldEqual 5
    // The highest-priority key (column 4) must be non-decreasing from top to bottom.
    (1 until rowCount).forall(i => sorted(i - 1, 4) <= sorted(i, 4)) shouldBe true
  }

}
dlwh commented 4 years ago

Thanks for sending this my way! I'll try to make it breezy sometime soon and I'll follow up.