alexeyr closed this 7 years ago.
I also have a refactoring of PrimitiveOps in https://github.com/scalan/virtualization-lms-core/tree/refactor-primitive-ops, but for some reason the rewrites in TestStencil don't seem to work completely.
Alexey, concerning 554e4adb4c923524590b43483085c323d63ba4a2: the Manifest is already passed as an implicit parameter to ArrayNew. If you want to specify the manifest explicitly, you can already do so without any changes to the existing codebase. For example:

Implicit version (creating an array of integers of size 10):

```scala
val arr = ArrayNew[Int](Const(10))
```

Explicit version (creating an array of integers of size 10):

```scala
val arr = ArrayNew(Const(10))(manifest[Int])
```
Implicits are in fact extra parameters of the method or class definition they belong to. The Scala compiler attempts to fill them in automatically, but it also lets you pass them by hand. Passing them by hand becomes a requirement when several implicit methods or objects in scope make the resolution ambiguous.
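Here is a small, self-contained sketch (Scala 2, no LMS dependency) of two points. First, the manifest can be supplied either implicitly or explicitly. Second, because it lives in a second parameter list, it does not take part in the generated case-class equality. `ArrayNewLike` is a hypothetical stand-in for ArrayNew and is only assumed to have a similar shape.

```scala
object ImplicitParamsSketch {
  // Hypothetical stand-in for ArrayNew: the element Manifest sits in a
  // second, implicit parameter list (assumed shape, not LMS's definition).
  case class ArrayNewLike[T](length: Int)(implicit val m: Manifest[T])

  // Implicit version: the compiler supplies manifest[Int].
  val implicitInt = ArrayNewLike[Int](10)
  // Explicit versions: the same parameter list, filled in by hand.
  val explicitInt = ArrayNewLike[Int](10)(manifest[Int])
  val explicitDouble = ArrayNewLike[Double](10)(manifest[Double])

  // Generated equals/hashCode only look at the first parameter list,
  // so an Int node and a Double node of the same length compare equal.
  val equalIgnoringType: Boolean = explicitInt == explicitDouble
  // The manifests themselves do differ.
  val manifestsDiffer: Boolean = explicitInt.m != explicitDouble.m
}
```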
Have a look here to learn more:
@astojanov I am well aware of this. The issue I am concerned about, as mentioned in the commit message, is that implicit parameters aren't considered in case-class equality. So if you already have an ArrayNew[Int] and create an ArrayNew[Double] with the same length, the Sym for the ArrayNew[Int] will be returned from findOrCreateDefinition, and you can get a type mismatch in the generated code. In this specific case it should be masked by reflectMutable in array_obj_new, but it still applies to other types. An alternative solution would be to override equals, if that's preferable.

Or, now that I think of it, findDefinition could check the type as well as equality. That would be a more general solution if it works. Let me check...
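The following is a minimal sketch of that idea: a lookup that checks the stored symbol's type as well as structural equality of the definition. `Graph`, `findOrCreate`, `NewArr` and `Sym` are hypothetical stand-ins, not LMS's actual findOrCreateDefinition. The boolean flag only exists to compare both behaviours side by side.

```scala
import scala.collection.mutable.ArrayBuffer

object TypedLookupSketch {
  sealed trait Def
  // Hypothetical IR node: element type lives in an implicit parameter list,
  // so it is invisible to the generated equals.
  case class NewArr[T](length: Int)(implicit val m: Manifest[T]) extends Def

  final case class Sym(id: Int)

  // Each stored statement records the Manifest of the symbol's type.
  final class Graph(checkType: Boolean) {
    private val stms = ArrayBuffer.empty[(Sym, Manifest[_], Def)]

    def findOrCreate[A](d: Def)(implicit tp: Manifest[A]): Sym =
      stms.collectFirst {
        // Structural equality on the Def, optionally also the symbol's type.
        case (s, stp, sd) if sd == d && (!checkType || stp == tp) => s
      }.getOrElse {
        val s = Sym(stms.size)
        stms += ((s, tp, d))
        s
      }
  }

  // Registers an Int array node and a Double array node of the same length.
  def demo(checkType: Boolean): (Sym, Sym) = {
    val g = new Graph(checkType)
    val si = g.findOrCreate[Array[Int]](NewArr[Int](17))
    val sd = g.findOrCreate[Array[Double]](NewArr[Double](17))
    (si, sd)
  }
}
```

Without the type check, both calls return the same symbol. With it, the Double node gets its own symbol.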
Yes, it does work as expected in our case and removes a warning in existing tests. @TiarkRompf The last commit combines lhs and mhs in TTP into a single field. It simplifies logic in a few places and shouldn't change any behavior, but it can be excluded if you prefer.
Consider the following code snippet:
```scala
import org.scalatest.FunSpec
import scala.virtualization.lms.common._

class TestArrayEqual extends FunSpec {
  val TestIR = new BaseFatExp with ArrayOpsExp
  import TestIR._

  def f[T:Manifest, U:Manifest](): Rep[Unit] = {
    val (t, u) = (NewArray[T](Const(17)), NewArray[U](Const(17)))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))
    (t, u) match {
      case (Def(Reflect(a@ArrayNew(_), _, _)), Def(Reflect(b@ArrayNew(_), _, _))) => {
        println("Comparing case classes: " + (a == b && a.m == b.m))
      }
    }
    val dt = findOrCreateDefinition(Reflect(ArrayNew[T](Const(17)), Alloc(), List()), t.pos)
    val du = findOrCreateDefinition(Reflect(ArrayNew[U](Const(17)), Alloc(), List()), u.pos)
    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  describe("TestArray") {
    reifyEffects(f[Int, Int]())
    reifyEffects(f[Int, Double]())
  }
}
```
This produces the following output:

```text
t = Sym(0)
u = Sym(1)
Comparing symbols: false
Comparing case classes: true
Comparing findOrCreateDefinition: true
t = Sym(3)
u = Sym(4)
Comparing symbols: false
Comparing case classes: false
Comparing findOrCreateDefinition: true
```
Clearly, in this scenario findOrCreateDefinition "fails" in the second case, because it returns the same definition for both ArrayNew[T] and ArrayNew[U]. However, note that ArrayNew returns a mutable symbol. Applying findOrCreateDefinition or findDefinition to it is pointless, since you might end up creating 10 different integer arrays, and then which definition should be returned? This is also why the first case yields the same definition, even though two different integer arrays are created.
findOrCreateDefinition is called whenever toAtom is invoked. LMS uses it to perform common subexpression elimination (CSE) by not creating definitions that already exist. reflectEffects, however, does not use findOrCreateDefinition, since CSE is pointless for mutable expressions. findOrCreateDefinition and findDefinition are best used for immutable expressions.
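To make the distinction concrete, here is a toy sketch. The immutable path looks up an existing definition before creating a symbol, while the mutable path always allocates a fresh one. `Graph`, `toAtomLike` and `reflectMutableLike` are hypothetical names, not LMS's actual toAtom or reflectMutable. Node arguments are plain Ints only to keep the example self-contained.

```scala
import scala.collection.mutable.ArrayBuffer

object CseSketch {
  sealed trait Def
  final case class Plus(a: Int, b: Int) extends Def // immutable node
  final case class NewBuf(length: Int) extends Def  // mutable allocation

  final case class Sym(id: Int)

  final class Graph {
    private val stms = ArrayBuffer.empty[(Sym, Def)]

    private def fresh(d: Def): Sym = {
      val s = Sym(stms.size)
      stms += ((s, d))
      s
    }

    // Immutable: reuse a structurally equal definition if one exists (CSE).
    def toAtomLike(d: Def): Sym =
      stms.collectFirst { case (s, sd) if sd == d => s }.getOrElse(fresh(d))

    // Mutable: every allocation needs its own symbol, so no lookup at all.
    def reflectMutableLike(d: Def): Sym = fresh(d)

    def size: Int = stms.size
  }
}
```

Two identical `Plus(1, 2)` requests share one symbol, while two identical `NewBuf(10)` allocations get two.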
If a scenario exists where findOrCreateDefinition fails on immutable expressions, then I would consider fixing findOrCreateDefinition as a requirement.

Using findDefinition to search through dependencies is quite a different scenario, and I am not considering it here. In that case, I believe this would be the wrong approach, since it creates ambiguity when effectful statements are used.
> If a scenario exists where findOrCreateDefinition fails on immutable expressions, then I would consider fixing findOrCreateDefinition as a requirement.

VectorLiteral (though that's in tests), ArrayFromSeq (reflectMutable is commented out there), and ListNew, each with empty sequences as arguments. ImplicitConvert is another example.
Well, ignoring VectorLiteral, since it is in tests, the rest look fine to me:
```scala
import org.scalatest.FunSpec
import scala.virtualization.lms.common._
import scala.collection.immutable.{List => ScalaList}

class TestArrayEqual extends FunSpec {
  val TestIR = new BaseFatExp with ArrayOpsExp with ListOpsExp with ImplicitOpsExp
  import TestIR._

  def f[T:Manifest, U:Manifest](): Rep[Unit] = {
    val (t, u) = (NewArray[T](Const(17)), NewArray[U](Const(17)))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))
    (t, u) match {
      case (Def(Reflect(a@ArrayNew(_), _, _)), Def(Reflect(b@ArrayNew(_), _, _))) => {
        println("Comparing case classes: " + (a == b && a.m == b.m))
      }
    }
    val dt = findOrCreateDefinition(Reflect(ArrayNew[T](Const(17)), Alloc(), ScalaList()), t.pos)
    val du = findOrCreateDefinition(Reflect(ArrayNew[U](Const(17)), Alloc(), ScalaList()), u.pos)
    println("dt: " + dt)
    println("du: " + du)
    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  def arrayFromSeq[T:Manifest, U:Manifest](sT: Seq[Exp[T]], sU: Seq[Exp[U]]): Rep[Unit] = {
    val (t, u) = (array_obj_fromseq(sT), array_obj_fromseq(sU))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))
    (t, u) match {
      case (Def(a@ArrayFromSeq(_)), Def(b@ArrayFromSeq(_))) => {
        println("Comparing case classes: " + (a == b && a.m == b.m))
      }
    }
    val dt = findOrCreateDefinition(ArrayFromSeq(sT), t.pos)
    val du = findOrCreateDefinition(ArrayFromSeq(sU), u.pos)
    println("dt: " + dt)
    println("du: " + du)
    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  def listNew[T:Manifest, U:Manifest](sT: Seq[Exp[T]], sU: Seq[Exp[U]]): Rep[Unit] = {
    val (t, u) = (list_new(sT), list_new(sU))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))
    (t, u) match {
      case (Def(a@ListNew(_)), Def(b@ListNew(_))) => {
        println("Comparing case classes: " + (a == b))
      }
    }
    val dt = findOrCreateDefinition(ListNew(sT), t.pos)
    val du = findOrCreateDefinition(ListNew(sU), u.pos)
    println("dt: " + dt)
    println("du: " + du)
    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  def testImplicitOps[T:Manifest, U:Manifest](sT: Exp[T], sU: Exp[U])
                     (implicit f1: T => String, f2: U => String): Rep[Unit] = {
    val (t, u) = (implicit_convert[T, String](sT), implicit_convert[U, String](sU))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))
    (t, u) match {
      case (Def(a@ImplicitConvert(_)), Def(b@ImplicitConvert(_))) => {
        println("Comparing case classes: " + (a == b))
      }
    }
    val dt = findOrCreateDefinition(ImplicitConvert[T, String](sT), t.asInstanceOf[Exp[T]].pos)
    val du = findOrCreateDefinition(ImplicitConvert[U, String](sU), u.asInstanceOf[Exp[U]].pos)
    println("dt: " + dt)
    println("du: " + du)
    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  describe("TestArray") {
    println("Testing Mutable ArrayNew")
    println("===========================================")
    reifyEffects(f[Int, Int]())
    reifyEffects(f[Int, Double]())

    println("Testing Immutable ArrayFromSeq")
    println("===========================================")
    val seqA = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqB = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqC = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqD = Seq(fresh[Double], fresh[Double], fresh[Double])
    reifyEffects(arrayFromSeq(seqA, seqB))
    reifyEffects(arrayFromSeq(seqC, seqD))

    println("Testing Immutable ListNew")
    println("===========================================")
    val seqE = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqF = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqG = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqH = Seq(fresh[Double], fresh[Double], fresh[Double])
    reifyEffects(listNew(seqE, seqF))
    reifyEffects(listNew(seqG, seqH))

    println("Testing Immutable ImplicitOps")
    println("===========================================")
    val pInt = fresh[Int]
    val qInt = fresh[Int]
    val rInt = fresh[Int]
    val sDbl = fresh[Double]
    implicit val f1: (Int => String) = (t: Int) => { t.toString }
    implicit val f2: (Double => String) = (t: Double) => { t.toString }
    reifyEffects(testImplicitOps(pInt, qInt))
    reifyEffects(testImplicitOps(rInt, sDbl))
  }
}
```
The resulting output is:

```text
Testing Mutable ArrayNew
===========================================
t = Sym(0)
u = Sym(1)
Comparing symbols: false
Comparing case classes: true
dt: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
du: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
Comparing findOrCreateDefinition: true
t = Sym(3)
u = Sym(4)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
du: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
Comparing findOrCreateDefinition: true
Testing Immutable ArrayFromSeq
===========================================
t = Sym(18)
u = Sym(19)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(18),ArrayFromSeq(List(Sym(6), Sym(7), Sym(8))))
du: TP(Sym(19),ArrayFromSeq(List(Sym(9), Sym(10), Sym(11))))
Comparing findOrCreateDefinition: false
t = Sym(20)
u = Sym(21)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(20),ArrayFromSeq(List(Sym(12), Sym(13), Sym(14))))
du: TP(Sym(21),ArrayFromSeq(List(Sym(15), Sym(16), Sym(17))))
Comparing findOrCreateDefinition: false
Testing Immutable ListNew
===========================================
t = Sym(34)
u = Sym(35)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(34),ListNew(List(Sym(22), Sym(23), Sym(24))))
du: TP(Sym(35),ListNew(List(Sym(25), Sym(26), Sym(27))))
Comparing findOrCreateDefinition: false
t = Sym(36)
u = Sym(37)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(36),ListNew(List(Sym(28), Sym(29), Sym(30))))
du: TP(Sym(37),ListNew(List(Sym(31), Sym(32), Sym(33))))
Comparing findOrCreateDefinition: false
Testing Immutable ImplicitOps
===========================================
t = Sym(42)
u = Sym(43)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(42),ImplicitConvert(Sym(38)))
du: TP(Sym(43),ImplicitConvert(Sym(39)))
Comparing findOrCreateDefinition: false
t = Sym(44)
u = Sym(45)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(44),ImplicitConvert(Sym(40)))
du: TP(Sym(45),ImplicitConvert(Sym(41)))
Comparing findOrCreateDefinition: false
```
As you can see, findOrCreateDefinition matches each case-class definition with the correct statement. This is because ImplicitConvert, ListNew and ArrayFromSeq are all immutable, so the manifest is not needed. I hope this explanation brings some closure.
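The following minimal sketch shows why the immutable cases stay distinct. It uses hypothetical stand-ins, `SymLike` and `ArrayFromSeqLike`, which are not LMS classes. The generated case-class equality compares the argument symbols, so different inputs give different definitions. The same rule also means that two empty argument lists compare equal whatever their element type, because type arguments are erased.

```scala
object ArgEqualitySketch {
  // Hypothetical stand-ins: symbols compare by id only, as Sym does in the output.
  final case class SymLike[T](id: Int)
  final case class ArrayFromSeqLike[T](xs: Seq[SymLike[T]])

  val a = ArrayFromSeqLike(Seq(SymLike[Int](6), SymLike[Int](7)))
  val b = ArrayFromSeqLike(Seq(SymLike[Int](9), SymLike[Int](10)))
  // Different argument symbols, so different definitions.
  val distinctArgsEqual: Boolean = a == b

  val emptyInt = ArrayFromSeqLike(Seq.empty[SymLike[Int]])
  val emptyDouble = ArrayFromSeqLike(Seq.empty[SymLike[Double]])
  // No symbols to compare, and T is erased, so these are equal.
  val emptyArgsEqual: Boolean = emptyInt == emptyDouble
}
```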
Also note that in 80fdb755f0f2853b5faf069012da1d5a8e7111a2, adding the manifest as an implicit argument to findDefinition does not do anything unless you actually use it.
Consider instead:

```scala
arrayFromSeq(Seq.empty[Exp[Int]], Seq.empty[Exp[Double]])
```

and

```scala
def testImplicitOps[T:Manifest, U:Manifest](s: Exp[Int])
                   (implicit f1: Int => T, f2: Int => U): Rep[Unit] = {
  val (t, u) = (implicit_convert[Int, T](s), implicit_convert[Int, U](s))
  println("t = " + t)
  println("u = " + u)
  println("Comparing symbols: " + (t == u))
  (t, u) match {
    case (Def(a@ImplicitConvert(_)), Def(b@ImplicitConvert(_))) => {
      println("Comparing case classes: " + (a == b))
    }
  }
  val dt = findOrCreateDefinition(ImplicitConvert[Int, T](s), t.asInstanceOf[Exp[T]].pos)
  val du = findOrCreateDefinition(ImplicitConvert[Int, U](s), u.asInstanceOf[Exp[U]].pos)
  println("dt: " + dt)
  println("du: " + du)
  println("Comparing findOrCreateDefinition: " + (du == dt))
  println()
}

testImplicitOps[Long, Double](fresh[Int])(_.toLong, _.toDouble)
```
> Also note that in 80fdb75, adding the manifest as an implicit argument in findDefinition does not do anything, unless you actually use it.

It's passed to infix_defines.
Yes, correct. In the following case

```scala
arrayFromSeq(Seq.empty[Exp[Int]], Seq.empty[Exp[Double]])
```

you will obtain:

```text
t = Sym(0)
u = Sym(0)
Comparing symbols: true
Comparing case classes: true
dt: TP(Sym(0),ArrayFromSeq(List()))
du: TP(Sym(0),ArrayFromSeq(List()))
Comparing findOrCreateDefinition: true
```
which is, in a way, expected given the interface LMS provides. Technically, an empty array of Int is pretty much equivalent to an empty array of Double if both are immutable. This is why, for example, scala.collection.immutable.Nil is typeless (its element type is Nothing). It is also why:

```scala
println(Seq.empty[Exp[Int]] == Seq.empty[Exp[Double]])
println(Seq.empty[Int] == Seq.empty[Double])
```

prints:

```text
true
true
```
The type of the array only matters if the array is in fact mutable, and you cannot construct a mutable array using ArrayFromSeq. You should use ArrayNew instead, and, as discussed before, that one does not go through findOrCreateDefinition.

Therefore, if you really need behaviour different from what LMS expects, I would say a better approach is to override findDefinition in your project rather than in LMS.
> Technically an empty array of Int is pretty much equivalent to an empty array of Double if both are immutable. This is why for example scala.collection.immutable.Nil is typeless.

Not to the type system, they aren't! Arrays aren't covariant, and you'll just get a type error if you try to use an empty Array[Int] as if it were an Array[Double]. Consider staging a function like
```scala
def f(x: Rep[Unit]): Rep[Array[Double]] = {
  // imagine lots of code which ends up calling array_obj_fromseq(Seq.empty[Int]) somewhere inside
  array_obj_fromseq(Seq.empty[Double])
}
```
The generated class will extend Unit => Array[Double], but apply will return Array[Int] and won't typecheck. (It will actually return Array[Nothing] due to a bug in codegen, but that's a different issue.)

For a non-empty array example: array_obj_from_seq(Seq(unit(1))) and array_obj_from_seq(Seq(unit(1.0))).
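Here is a quick Scala 2 check of the covariance point that does not depend on LMS. JVM arrays carry their element type at runtime, so an empty Array[Int] and an empty Array[Double] are different objects, unlike empty Seqs. Using one as the other fails with a ClassCastException. The object and value names are purely illustrative.

```scala
object ArrayReificationSketch {
  // Static type Any, so no compile-time error stops the misuse below.
  val ints: Any = Array.empty[Int]
  val doubles: Any = Array.empty[Double]

  // Arrays are reified on the JVM: int[] and double[] are different classes.
  val sameClass: Boolean = ints.getClass == doubles.getClass
  val intsLookLikeDoubles: Boolean = ints.isInstanceOf[Array[Double]]
  // Seq, by contrast, is erased and empty Seqs of any element type are equal.
  val seqsEqual: Boolean = Seq.empty[Int] == Seq.empty[Double]

  // Reusing an Array[Int] where an Array[Double] is expected fails at runtime.
  def castFails: Boolean =
    try ints.asInstanceOf[Array[Double]].length < 0
    catch { case _: ClassCastException => true }
}
```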
I agree on this. Arrays are not covariant, and ideally, to behave identically to the JVM, the representation of arrays should deal with reifiable types. However, since the staged version of arrays, namely ArrayOps, deals with generics, it has to deal with non-reifiable types. Its implementation therefore behaves much like other containers such as List or Seq.

I cannot reproduce your latest example, but I do agree with the previous example concerning ImplicitConvert. Moving the manifest inside the case class makes sense, since you want to know which type you are converting to.
For the non-empty array example, this is exactly what happens, because it is expected. In Scala:

```scala
scala> Seq[Double](1.0) == Seq[Int](1)
res0: Boolean = true
```
> For the non-empty array example, this is exactly what happens, because it is expected.

Yes, it's expected currently. But again, this gives you ill-typed generated code from well-typed LMS code, as follows:
```scala
val arrDouble: Rep[Array[Double]] = array_obj_from_seq(Seq(unit(1.0)))
val arrInt: Rep[Array[Int]] = array_obj_from_seq(Seq(unit(1)))
val someOtherArray: Rep[Array[A]] = // doesn't matter
val i: Rep[Int] = arrInt(0)
val j: Rep[Int] = someOtherArray(i)
```

(I've given the types explicitly, but that's also how Scala would infer them.) In the generated code you should now get something like

```scala
val x1 = Array(1.0) // used for both arrDouble and arrInt
val x2 = x1(0) // should be Int, but actually Double
val x3 = ...
val x4 = x3(x2) // oops
```
This pull request doesn't fully solve this case, though. It is probably necessary (and may be sufficient) to move the Manifest inside Const as well.
I finally see what you mean, and I agree this needs to be fixed. I think the fix should not be limited to findDefinition but should also cover other places, such as Const(1) == Const(1.0). That is, in a way, also expected behaviour, because 1 == 1.0 on the JVM. I am not sure what the right approach is or how LMS should treat this, but for my use case I would prefer those two not to be equal objects.
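Below is a minimal sketch of the Const issue outside LMS, using hypothetical classes `ConstUntyped` and `ConstTyped` in place of Const. Scala's `==` treats boxed 1 and 1.0 as equal, so a node that only holds the value cannot tell them apart. Carrying the Manifest as an ordinary first-list field, one possible reading of "moving the Manifest inside Const", separates them.

```scala
object ConstEqualitySketch {
  // Hypothetical: a Const that only holds the value.
  final case class ConstUntyped[T](x: T)
  // Hypothetical: the Manifest as a regular field, so it takes part in equals/hashCode.
  final case class ConstTyped[T](x: T, tp: Manifest[T])
  def const[T](x: T)(implicit tp: Manifest[T]): ConstTyped[T] = ConstTyped(x, tp)

  // Boxed 1 and 1.0 compare equal under Scala's ==, so the untyped nodes collide.
  val untypedEqual: Boolean = ConstUntyped(1) == ConstUntyped(1.0)
  // Same value, different manifests: no longer equal.
  val typedEqual: Boolean = const(1) == const(1.0)
  val sameTypeEqual: Boolean = const(1) == const(1)
}
```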
Yes, it seems distinguishing them is necessary. I've also added an option for explicit types to help debug type errors in generated code. It is not a mixin trait, to avoid problems with mixing it and ScalaNestedCodegen in the wrong order.

I've split this into three pull requests (this one, #105 and #106) to make the review simpler.
I finally had a chance to look at all the changes and the discussion. Here is what I think:

Const.equals should be merged as is.

For defines and findDefinition, I agree that these should also check the types. But I would prefer to pass the manifest as an explicit parameter rather than an implicit one. The reason is that sometimes we have to use findDefinition internally with type Any (for lack of a precise static type), and we have to guard against passing in Manifest[Any] by accident.

I agree that combining lhs and mhs is an improvement, but I'm afraid it might break things in Delite or in the fusion branch that is yet to be merged. So I would hold off on that for the moment (but we can have it as a separate PR).

Coming to think more about it, I wonder if the change to defines and findDefinition is really necessary. It seems to make a difference only for cases of a particular pattern (like List.empty). Maybe it would make more sense to handle those cases explicitly by including the Typ instance as a proper parameter in the Def (i.e. use case class ListEmpty[A](tp: Typ[A]))?
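For illustration, here is the difference between a Def without a type field and one with the type as a proper parameter, as proposed in this comment. Manifest stands in for LMS 1.0.x's Typ, an assumption made to keep the sketch self-contained. Both class names are hypothetical.

```scala
object ListEmptySketch {
  // Without a type field, every empty-list node is structurally the same.
  final case class ListEmptyUntyped[A]()
  // With the type as a proper first-list parameter (Manifest standing in for Typ).
  final case class ListEmpty[A](tp: Manifest[A])

  val untypedCollide: Boolean = ListEmptyUntyped[Int]() == ListEmptyUntyped[Double]()
  val typedCollide: Boolean = ListEmpty(manifest[Int]) == ListEmpty(manifest[Double])
}
```

The trade-off raised in the reply is that every such Def has to remember to carry the type field, whereas a type check in the lookup applies to all of them at once.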
That's what I did originally, but there are two problems with it: 1) you also need to consider places where you have List/Seq/Option[Rep[A]] or Rep[A]*; 2) it's easy to miss when new defs are added. Or, say,

> ArrayFromSeq (reflectMutable is commented out there)

now you would need to be careful to add or remove the Typ when making such changes. By contrast, changes to defines/findDefinition only have to be done once.
> For defines and findDefinition, I agree that these should also check the types. But I would prefer to pass the manifest as an explicit parameter, instead of as an implicit. The reason is that sometimes we have to use findDefinition internally with type Any (for lack of a precise static type), and we have to guard against passing in Manifest[Any] by accident.

That shouldn't be a problem in 1.0.x, because there is no implicit Typ[Any], I think? It can certainly still be made explicit if preferred.
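The sketch below illustrates the Manifest[Any] concern and shows why a Typ-style type class avoids it. It assumes, as this reply suggests, that no Typ[Any] instance exists. `TypLike`, `typeViaManifest` and `typeViaTyp` are hypothetical names, not LMS APIs. The compiler will always synthesise a Manifest[Any] for a value whose static type is Any. A type class with no Any instance turns the same mistake into a compile error instead.

```scala
object ExplicitTypeSketch {
  // Hypothetical stand-in for a type class like Typ, with instances only
  // for the types we choose to support (deliberately none for Any).
  trait TypLike[T] { def name: String }
  implicit val intTyp: TypLike[Int] = new TypLike[Int] { def name = "Int" }

  def typeViaManifest[T](x: T)(implicit m: Manifest[T]): String = m.toString
  def typeViaTyp[T](x: T)(implicit t: TypLike[T]): String = t.name

  val statically: Any = 42
  // Compiles silently: the compiler synthesises Manifest[Any].
  val viaManifest: String = typeViaManifest(statically)
  // typeViaTyp(statically) would not compile: there is no TypLike[Any].
  val viaTyp: String = typeViaTyp(42)
}
```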
> I agree that combining lhs and mhs is an improvement, but i'm afraid that it might break things in Delite or in the fusion branch that is yet to be merged. So I would hold off on that for the moment (but we can have it as a separate PR).

Moved to #110, though it will need to be updated once there is a decision on defines/findDefinition.
Made against develop-0.9.x because develop-1.0.x currently looks a bit messy and I don't know if you are planning to rewrite it. I can rebase against a different branch if desired.