kindlychung opened this issue 8 years ago.
`where(x > 2.0)` would give you the first sequence (the indices where the condition holds). There is nothing yet for the second (the indices where it fails), but that makes sense to add.
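Until something like that exists, the second sequence can be recovered from the first: it is every index of the vector that `where` did not return. A minimal plain-Scala sketch (no Breeze dependency; `Complement` and `complementIndices` are hypothetical names), assuming the first sequence has already been computed:

```scala
// Plain-Scala sketch (no Breeze): recover the indices that fail a condition
// from the indices that satisfy it. Object and method names are hypothetical.
object Complement {
  // n is the vector length; hits are the indices returned for the condition.
  def complementIndices(n: Int, hits: IndexedSeq[Int]): IndexedSeq[Int] = {
    val hitSet = hits.toSet        // fast membership checks
    (0 until n).filterNot(hitSet)  // keep the indices not in hits, ascending
  }
}
```

For the example in the question, `complementIndices(4, IndexedSeq(2, 3))` gives the indices 0 and 1.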
Hi, I would like to work on this issue.
I don't understand. Does `where` already support a generic condition (i.e. not just `!= 0`)? I get this error:
```scala
scala> v
res15: breeze.linalg.DenseVector[Double] = DenseVector(-1.0, 0.0, 1.0, 2.0)

scala> where(v > 0)
<console>:22: error: value > is not a member of breeze.linalg.DenseVector[Double]
where(v > 0)
        ^

scala> where(v > 0.0)
<console>:22: error: value > is not a member of breeze.linalg.DenseVector[Double]
where(v > 0.0)
```
In any case, let me know how you would like the function to look.
Element-wise comparisons have a `:` in front of the operator. Here, that would be `where(v :> 0)` for a vector of integers, and `where(v :> 0.0)` for your vector of doubles.
Remember that the scalar must have the same type as the vector's elements (Double, Int, Long, ...). I would recommend reading the Operations section of the cheat sheet: https://github.com/scalanlp/breeze/wiki/Linear-Algebra-Cheat-Sheet
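To make the mechanics concrete: `v :> 0.0` produces an element-wise Boolean vector, and `where` then returns the positions of its true entries. The sketch below models those two steps with ordinary Scala collections instead of Breeze; the names `ElementwiseWhere`, `greaterThan` and `whereTrue` are hypothetical, chosen for illustration only:

```scala
// Plain-Scala model (no Breeze) of what `where(v :> 0.0)` computes.
// All names here are hypothetical stand-ins, not Breeze API.
object ElementwiseWhere {
  // Analogue of `v :> threshold`: one Boolean per element.
  def greaterThan(v: IndexedSeq[Double], threshold: Double): IndexedSeq[Boolean] =
    v.map(_ > threshold)

  // Analogue of `where(mask)`: the indices of the true entries, ascending.
  def whereTrue(mask: IndexedSeq[Boolean]): IndexedSeq[Int] =
    mask.indices.filter(i => mask(i))
}
```

With the vector from the REPL session above, `whereTrue(greaterThan(IndexedSeq(-1.0, 0.0, 1.0, 2.0), 0.0))` yields the indices 2 and 3; negating the mask with `.map(!_)` before calling `whereTrue` gives the remaining indices 0 and 1.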
A possible implementation is:
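```scala
import breeze.linalg.DenseVector

// Proposed as a method on `where` (the tests below call `where.partIndex`).
// Returns (indices where `condition` holds, indices where it does not).
def partIndex[T](v: DenseVector[T], condition: T => Boolean): (IndexedSeq[Int], IndexedSeq[Int]) = {
  val satisfyCondition = v.map(condition) // DenseVector[Boolean] mask
  (where(satisfyCondition), where(!satisfyCondition))
}
```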
```scala
test("partIndex") {
  assert(where.partIndex(DenseVector(1, 2, 3, 4, 5, 6), (x: Int) => x % 2 == 0) == (IndexedSeq(1, 3, 5), IndexedSeq(0, 2, 4)))
  assert(where.partIndex(DenseVector(1, 2, 3), (x: Int) => x > 100) == (IndexedSeq(), IndexedSeq(0, 1, 2)))
  assert(where.partIndex(DenseVector(1, 2, 3), (x: Int) => x < 100) == (IndexedSeq(0, 1, 2), IndexedSeq()))
  assert(where.partIndex(DenseVector(1, 2, 3), (x: Int) => x == 1) == (IndexedSeq(0), IndexedSeq(1, 2)))
}
```
If this is fine, I can make a pull request. I was thinking of writing it as a method in `where` or in `DenseArray`.
What is the best way to partition indices into two collections by a certain predicate?
For example (wishful thinking), `DenseVector(1.0, 2.0, 3.0, 4.0).partIndex(x => x > 2.0)` should give `(Seq(2, 3), Seq(0, 1))`.
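The wishful `.partIndex` call syntax can be obtained with an implicit class. A minimal sketch, written against `IndexedSeq` so it runs without Breeze; `PartIndexSyntax` and `PartIndexOps` are hypothetical names, and wrapping `DenseVector[T]` instead would presumably follow the same pattern:

```scala
// Hypothetical extension providing `xs.partIndex(p)`; not part of Breeze.
object PartIndexSyntax {
  implicit class PartIndexOps[T](private val xs: IndexedSeq[T]) extends AnyVal {
    // One pass over 0 until xs.length: (indices where p holds, the rest).
    def partIndex(p: T => Boolean): (IndexedSeq[Int], IndexedSeq[Int]) =
      xs.indices.partition(i => p(xs(i)))
  }
}
```

After `import PartIndexSyntax._`, `IndexedSeq(1.0, 2.0, 3.0, 4.0).partIndex(x => x > 2.0)` equals `(Seq(2, 3), Seq(0, 1))`, matching the expected result above. Using `partition` visits each index once, instead of building a Boolean mask and calling `where` twice.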