scalanlp / breeze

Breeze is/was a numerical processing library for Scala.
https://www.scalanlp.org
Apache License 2.0

Partition index #551

Open kindlychung opened 8 years ago

kindlychung commented 8 years ago

What is the best way to partition indices into two collections by a certain predicate?

For example (by wishful thinking), DenseVector(1.0, 2.0, 3.0, 4.0).partIndex(x => x > 2.0) should give (Seq(2, 3), Seq(0,1)).
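To make the requested behaviour concrete, here is a minimal sketch using only the Scala standard library, not Breeze. The name partIndex is borrowed from the wishful example above; the signature is an assumption for illustration and does not exist in Breeze.

```scala
// Hypothetical helper, standard library only (not part of Breeze).
// Returns (indices whose element satisfies p, indices whose element does not).
def partIndex[T](xs: IndexedSeq[T])(p: T => Boolean): (IndexedSeq[Int], IndexedSeq[Int]) =
  xs.indices.partition(i => p(xs(i)))

// Indices 2 and 3 hold values greater than 2.0; indices 0 and 1 do not.
val (matching, rest) = partIndex(Vector(1.0, 2.0, 3.0, 4.0))(_ > 2.0)
```

Both index sequences come out in ascending order because partition preserves the order of the underlying range.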

dlwh commented 8 years ago

where(x > 2.0) would give you the first seq. There's nothing for giving the second, but that makes sense to add.

DonBeo commented 8 years ago

Hi, I would like to work on this issue.

I don't quite understand. Does where already support a generic condition (i.e. something other than != 0)? I get this error:

scala> v
res15: breeze.linalg.DenseVector[Double] = DenseVector(-1.0, 0.0, 1.0, 2.0)

scala> where(v > 0)
<console>:22: error: value > is not a member of breeze.linalg.DenseVector[Double]
       where(v > 0)
               ^

scala> where(v > 0.0)
<console>:22: error: value > is not a member of breeze.linalg.DenseVector[Double]
       where(v > 0.0)

If it does not, let me know how you would like the function to look.

asgeissler commented 8 years ago

Element-wise comparisons have a ':' in front of the operator. Here, that would be

where(v :> 0) and where(v :> 0.0)

Remember that the number must have the same type as the elements of the vector (Double, Int, Long, ...). I would recommend reading the Operations section of the cheat sheet: https://github.com/scalanlp/breeze/wiki/Linear-Algebra-Cheat-Sheet
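As a rough model of what the two steps compute, the sketch below uses plain Scala collections instead of Breeze. The helpers gtMask and whereTrue are hypothetical names that stand in for the element-wise comparison and for where; they are not Breeze functions.

```scala
// Hypothetical stand-in for an element-wise "greater than": one Boolean per element.
def gtMask(v: IndexedSeq[Double], threshold: Double): IndexedSeq[Boolean] =
  v.map(_ > threshold)

// Hypothetical stand-in for where: the indices at which the mask is true.
def whereTrue(mask: IndexedSeq[Boolean]): IndexedSeq[Int] =
  mask.indices.filter(i => mask(i))

val v = Vector(-1.0, 0.0, 1.0, 2.0)
val mask = gtMask(v, 0.0)          // false, false, true, true
val positive = whereTrue(mask)     // indices 2 and 3
```

The comparison produces a Boolean mask of the same length as the vector, and the index lookup is a separate step, which is why a plain `>` on the whole vector does not type-check.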

DonBeo commented 8 years ago

A possible implementation is:

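def partIndex[T](v: DenseVector[T], condition: T => Boolean): (IndexedSeq[Int], IndexedSeq[Int]) = {
  // Boolean mask: true where the element satisfies the condition
  val satisfyCondition = v.map(condition)
  // Indices where the mask is true, then indices where it is false
  (where(satisfyCondition), where(!satisfyCondition))
}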

test("partIndex") {
  assert(where.partIndex(DenseVector(1, 2, 3, 4, 5, 6), (x: Int) => x % 2 == 0) == (IndexedSeq(1, 3, 5), IndexedSeq(0, 2, 4)))
  assert(where.partIndex(DenseVector(1, 2, 3), (x: Int) => x > 100) == (IndexedSeq(), IndexedSeq(0, 1, 2)))
  assert(where.partIndex(DenseVector(1, 2, 3), (x: Int) => x < 100) == (IndexedSeq(0, 1, 2), IndexedSeq()))
  assert(where.partIndex(DenseVector(1, 2, 3), (x: Int) => x == 1) == (IndexedSeq(0), IndexedSeq(1, 2)))
}

If this is fine, I can make a pull request. I was thinking of writing it as a method on where or on DenseArray.