TiarkRompf / virtualization-lms-core

A Framework for Runtime Code Generation and Compiled DSLs
http://scala-lms.github.com
BSD 3-Clause "New" or "Revised" License
Fix incorrect equality for ArrayNew and friends, add missing rewrites in PrimitiveOps #104

Closed alexeyr closed 7 years ago

alexeyr commented 8 years ago

Made against develop-0.9.x because develop-1.0.x currently looks a bit messy and I don't know if you are planning to rewrite it. I can rebase against a different branch if desired.

alexeyr commented 8 years ago

I also have a refactoring of PrimitiveOps in https://github.com/scalan/virtualization-lms-core/tree/refactor-primitive-ops, but for some reason the rewrites in TestStencil don't seem to work completely.

astojanov commented 8 years ago

Alexey, concerning 554e4adb4c923524590b43483085c323d63ba4a2, the Manifest is already passed as an implicit parameter in ArrayNew. If you want to specify the manifest explicitly, you can still do so without any changes to the initial codebase. For example:

Implicit version (creating array of integers, size 10):

val arr = ArrayNew[Int](Const(10))

Explicit version (creating array of integers, size 10):

val arr = ArrayNew(Const(10))(manifest[Int])

Implicits are in fact extra parameters of any method or class definition that they are part of. The Scala compiler attempts to match them automatically, and also provides the option to pass them manually. The latter becomes a requirement when several implicit methods or objects are in scope and make the matching ambiguous.

Have a look here to learn more:

alexeyr commented 8 years ago

@astojanov I am well aware of this. The issue I am concerned about, as mentioned in the commit message, is that implicit parameters aren't considered for equality of case classes. So if you already have an ArrayNew[Int] and create an ArrayNew[Double] with the same length, the Sym for the ArrayNew[Int] will be returned from findOrCreateDefinition and you can get type mismatch in generated code (in this specific case it should be masked by reflectMutable in array_obj_new, but it still applies to other types). An alternative solution would be to override equals, if that's preferable.
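
The equality point can be reproduced in plain Scala 2 without LMS. The sketch below is only an illustration: ArrayNewSketch is a hypothetical stand-in, not the LMS ArrayNew definition, and it assumes the manifest sits in a second (implicit) parameter list, as described above.

```scala
// Hypothetical stand-in for an IR node; not the actual LMS ArrayNew.
case class ArrayNewSketch[T](n: Int)(implicit val m: Manifest[T])

object CaseClassEqualityDemo {
  def main(args: Array[String]): Unit = {
    val ints    = ArrayNewSketch[Int](10)
    val doubles = ArrayNewSketch[Double](10)

    // The generated equals/hashCode only look at the first parameter list (n).
    println(ints == doubles)                   // true
    println(ints.hashCode == doubles.hashCode) // true

    // The manifests themselves differ, so the type information is there,
    // it is just not part of the comparison.
    println(ints.m == doubles.m)               // false
  }
}
```

Any table keyed on such nodes (a hash map, or a linear search using ==) would therefore treat the two allocations as the same definition.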

alexeyr commented 8 years ago

Or, now that I think of it, to check type as well as equality in findDefinition, which would be a more general solution if it works. Let me check...
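
A rough sketch of what "check type as well as equality" could look like, written against a toy definitions table rather than LMS itself. DefTable, Entry, FromSeqSketch, findTyped and findOrCreateTyped are all hypothetical names, and the recorded type is simplified to the element type.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical miniature of a definitions table; not the LMS implementation.
final case class FromSeqSketch(elems: List[Int]) // operands as symbol ids
final case class Entry(sym: Int, rhs: Any, tp: Manifest[_])

final class DefTable {
  private val entries = ArrayBuffer.empty[Entry]

  // Node equality only: an entry recorded for any type can be returned.
  def findByNode(rhs: Any): Option[Entry] = entries.find(_.rhs == rhs)

  // Node equality plus a match on the recorded type.
  def findTyped[T](rhs: Any)(implicit tp: Manifest[T]): Option[Entry] =
    entries.find(e => e.rhs == rhs && e.tp == tp)

  def findOrCreateTyped[T](rhs: Any)(implicit tp: Manifest[T]): Entry =
    findTyped[T](rhs).getOrElse {
      val e = Entry(entries.size, rhs, tp)
      entries += e
      e
    }
}

object TypedLookupDemo {
  def main(args: Array[String]): Unit = {
    val table = new DefTable
    val a = table.findOrCreateTyped[Int](FromSeqSketch(Nil))
    val b = table.findOrCreateTyped[Double](FromSeqSketch(Nil))
    val c = table.findOrCreateTyped[Int](FromSeqSketch(Nil))
    println(a.sym == b.sym) // false: same node, different recorded type
    println(a.sym == c.sym) // true: same node and type, entry reused
    println(table.findByNode(FromSeqSketch(Nil)).map(_.sym)) // Some(0)
  }
}
```

The appeal of this variant is that the node classes stay untouched; only the lookup changes.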

alexeyr commented 8 years ago

Yes, it does work as expected in our case and removes a warning in existing tests. @TiarkRompf The last commit combines lhs and mhs in TTP into a single field. It simplifies logic in a few places and shouldn't change any behavior, but can be excluded if you prefer.

astojanov commented 8 years ago

Consider the following code snippet:

import org.scalatest.FunSpec
import scala.virtualization.lms.common._

class TestArrayEqual extends FunSpec {

  val TestIR = new BaseFatExp with ArrayOpsExp
  import TestIR._

  def f[T:Manifest, U:Manifest] (): Rep[Unit] = {

    val (t, u) = (NewArray[T](Const(17)), NewArray[U](Const(17)))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))

    (t, u) match {
      case (Def(Reflect(a@ArrayNew(_), _, _)), Def(Reflect(b@ArrayNew(_), _, _))) => {
        println("Comparing case classes: " + (a == b && a.m == b.m))
      }
    }

    val dt = findOrCreateDefinition(Reflect(ArrayNew[T](Const(17)), Alloc(), List()), t.pos)
    val du = findOrCreateDefinition(Reflect(ArrayNew[U](Const(17)), Alloc(), List()), u.pos)

    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  describe("TestArray") {
    reifyEffects(f[Int, Int]())
    reifyEffects(f[Int, Double]())
  }
}

This results in the following output:

t = Sym(0)
u = Sym(1)
Comparing symbols: false
Comparing case classes: true
Comparing findOrCreateDefinition: true

t = Sym(3)
u = Sym(4)
Comparing symbols: false
Comparing case classes: false
Comparing findOrCreateDefinition: true

Clearly, findOrCreateDefinition "fails" in the second case, because it returns the same definition for both ArrayNew[T] and ArrayNew[U]. Note, however, that ArrayNew returns a mutable symbol. Applying findOrCreateDefinition or findDefinition to it is pointless: if you create 10 different Integer arrays, which definition should be returned? This is also why the first case yields the same definition, although two different integer arrays are created.

findOrCreateDefinition is called whenever toAtom is invoked, because LMS performs common subexpression elimination (CSE) by not creating definitions that already exist. reflectEffects, however, does not use findOrCreateDefinition, since CSE is pointless when dealing with mutable expressions. findOrCreateDefinition and findDefinition are best used for immutable expressions.
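
The split between the two paths can be pictured with a toy graph. This is only a sketch of the idea under simplifying assumptions: Graph, Plus, NewArr, toAtomSketch and reflectSketch are hypothetical names and do not mirror the LMS internals.

```scala
import scala.collection.mutable

// Hypothetical miniature IR; none of these names are LMS's own.
sealed trait Node
final case class Plus(a: Int, b: Int) extends Node // pure: operands are symbol ids
final case class NewArr(len: Int) extends Node     // allocation: has an effect

final class Graph {
  private val pureDefs = mutable.HashMap.empty[Node, Int]
  private var next = 0
  private def fresh(): Int = { val s = next; next += 1; s }

  // Pure path: reuse the symbol of a structurally equal node (CSE).
  def toAtomSketch(n: Node): Int = pureDefs.getOrElseUpdate(n, fresh())

  // Effectful path: always a fresh symbol, the table is never consulted.
  def reflectSketch(n: Node): Int = fresh()
}

object CseDemo {
  def main(args: Array[String]): Unit = {
    val g = new Graph
    println(g.toAtomSketch(Plus(1, 2)) == g.toAtomSketch(Plus(1, 2)))   // true
    println(g.reflectSketch(NewArr(17)) == g.reflectSketch(NewArr(17))) // false
  }
}
```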

If a scenario exists where findOrCreateDefinition fails on immutable expressions, then I would consider fixing findOrCreateDefinition as a requirement.

It is quite a different scenario (which I am not considering here) if findDefinition is used to search through dependencies. In that case, I believe this would be the wrong approach, since it creates ambiguity when effectful statements are used.

alexeyr commented 8 years ago

If a scenario exists where findOrCreateDefinition fails on immutable expressions, then I would consider fixing findOrCreateDefinition as a requirement.

VectorLiteral (though that's in tests), ArrayFromSeq (reflectMutable is commented out there), ListNew with empty sequences as arguments. ImplicitConvert is another example.

astojanov commented 8 years ago

Well, ignoring VectorLiteral, since it is in test, the rest look pretty fine to me:

import org.scalatest.FunSpec
import scala.virtualization.lms.common._
import scala.collection.immutable.{List => ScalaList}

class TestArrayEqual extends FunSpec {

  val TestIR = new BaseFatExp with ArrayOpsExp with ListOpsExp with ImplicitOpsExp
  import TestIR._

  def f[T:Manifest, U:Manifest] (): Rep[Unit] = {

    val (t, u) = (NewArray[T](Const(17)), NewArray[U](Const(17)))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))

    (t, u) match {
      case (Def(Reflect(a@ArrayNew(_), _, _)), Def(Reflect(b@ArrayNew(_), _, _))) => {
        println("Comparing case classes: " + (a == b && a.m == b.m))
      }
    }

    val dt = findOrCreateDefinition(Reflect(ArrayNew[T](Const(17)), Alloc(), ScalaList()), t.pos)
    val du = findOrCreateDefinition(Reflect(ArrayNew[U](Const(17)), Alloc(), ScalaList()), u.pos)

    println("dt: " + dt)
    println("du: " + du)

    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  def arrayFromSeq[T:Manifest, U:Manifest] (sT: Seq[Exp[T]], sU: Seq[Exp[U]]): Rep[Unit] = {

    val (t, u) = (array_obj_fromseq(sT), array_obj_fromseq(sU))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))

    (t, u) match {
      case (Def(a@ArrayFromSeq(_)), Def(b@ArrayFromSeq(_))) => {
        println("Comparing case classes: " + (a == b && a.m == b.m))
      }
    }

    val dt = findOrCreateDefinition(ArrayFromSeq(sT), t.pos)
    val du = findOrCreateDefinition(ArrayFromSeq(sU), u.pos)

    println("dt: " + dt)
    println("du: " + du)

    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  def listNew[T:Manifest, U:Manifest] (sT: Seq[Exp[T]], sU: Seq[Exp[U]]): Rep[Unit] = {

    val (t, u) = (list_new(sT), list_new(sU))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))

    (t, u) match {
      case (Def(a@ListNew(_)), Def(b@ListNew(_))) => {
        println("Comparing case classes: " + (a == b))
      }
    }

    val dt = findOrCreateDefinition(ListNew(sT), t.pos)
    val du = findOrCreateDefinition(ListNew(sU), u.pos)

    println("dt: " + dt)
    println("du: " + du)

    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  def testImplicitOps[T:Manifest, U:Manifest] (sT: Exp[T], sU: Exp[U])
    (implicit f1: T => String, f2: U => String): Rep[Unit] = {

    val (t, u) = (implicit_convert[T, String](sT), implicit_convert[U, String](sU))
    println("t = " + t)
    println("u = " + u)
    println("Comparing symbols: " + (t == u))

    (t, u) match {
      case (Def(a@ImplicitConvert(_)), Def(b@ImplicitConvert(_))) => {
        println("Comparing case classes: " + (a == b))
      }
    }

    val dt = findOrCreateDefinition(ImplicitConvert[T, String](sT), t.asInstanceOf[Exp[T]].pos)
    val du = findOrCreateDefinition(ImplicitConvert[U, String](sU), u.asInstanceOf[Exp[U]].pos)

    println("dt: " + dt)
    println("du: " + du)

    println("Comparing findOrCreateDefinition: " + (du == dt))
    println()
  }

  describe("TestArray") {

    println("Testing Mutable ArrayNew")
    println("===========================================")
    reifyEffects(f[Int, Int]())
    reifyEffects(f[Int, Double]())

    println("Testing Immutable ArrayFromSeq")
    println("===========================================")
    val seqA = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqB = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqC = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqD = Seq(fresh[Double], fresh[Double], fresh[Double])
    reifyEffects(arrayFromSeq(seqA, seqB))
    reifyEffects(arrayFromSeq(seqC, seqD))

    println("Testing Immutable ListNew")
    println("===========================================")
    val seqE = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqF = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqG = Seq(fresh[Int], fresh[Int], fresh[Int])
    val seqH = Seq(fresh[Double], fresh[Double], fresh[Double])
    reifyEffects(listNew(seqE, seqF))
    reifyEffects(listNew(seqG, seqH))

    println("Testing Immutable ImplicitOps")
    println("===========================================")
    val pInt = fresh[Int]
    val qInt = fresh[Int]
    val rInt = fresh[Int]
    val sDbl = fresh[Double]

    implicit val f1: (Int    => String) = (t: Int   ) => { t.toString }
    implicit val f2: (Double => String) = (t: Double) => { t.toString }

    reifyEffects(testImplicitOps(pInt, qInt))
    reifyEffects(testImplicitOps(rInt, sDbl))

  }
}

and the obtained output is:

Testing Mutable ArrayNew
===========================================
t = Sym(0)
u = Sym(1)
Comparing symbols: false
Comparing case classes: true
dt: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
du: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
Comparing findOrCreateDefinition: true

t = Sym(3)
u = Sym(4)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
du: TP(Sym(0),Reflect(ArrayNew(Const(17)),Summary(false,false,false,false,true,false,List(),List(),List(),List()),List()))
Comparing findOrCreateDefinition: true

Testing Immutable ArrayFromSeq
===========================================
t = Sym(18)
u = Sym(19)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(18),ArrayFromSeq(List(Sym(6), Sym(7), Sym(8))))
du: TP(Sym(19),ArrayFromSeq(List(Sym(9), Sym(10), Sym(11))))
Comparing findOrCreateDefinition: false

t = Sym(20)
u = Sym(21)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(20),ArrayFromSeq(List(Sym(12), Sym(13), Sym(14))))
du: TP(Sym(21),ArrayFromSeq(List(Sym(15), Sym(16), Sym(17))))
Comparing findOrCreateDefinition: false

Testing Immutable ListNew
===========================================
t = Sym(34)
u = Sym(35)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(34),ListNew(List(Sym(22), Sym(23), Sym(24))))
du: TP(Sym(35),ListNew(List(Sym(25), Sym(26), Sym(27))))
Comparing findOrCreateDefinition: false

t = Sym(36)
u = Sym(37)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(36),ListNew(List(Sym(28), Sym(29), Sym(30))))
du: TP(Sym(37),ListNew(List(Sym(31), Sym(32), Sym(33))))
Comparing findOrCreateDefinition: false

Testing Immutable ImplicitOps
===========================================
t = Sym(42)
u = Sym(43)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(42),ImplicitConvert(Sym(38)))
du: TP(Sym(43),ImplicitConvert(Sym(39)))
Comparing findOrCreateDefinition: false

t = Sym(44)
u = Sym(45)
Comparing symbols: false
Comparing case classes: false
dt: TP(Sym(44),ImplicitConvert(Sym(40)))
du: TP(Sym(45),ImplicitConvert(Sym(41)))
Comparing findOrCreateDefinition: false

As you can see, findOrCreateDefinition is able to match the case class definition with the correct statement. This is because ImplicitConvert, ListNew and ArrayFromSeq are all immutable, so the manifest is not needed. I hope this explanation will bring some closure.

astojanov commented 8 years ago

Also note that in 80fdb755f0f2853b5faf069012da1d5a8e7111a2, adding the manifest as an implicit argument in findDefinition does not do anything, unless you actually use it.

alexeyr commented 8 years ago

Consider instead:

arrayFromSeq(Seq.empty[Exp[Int]], Seq.empty[Exp[Double]])

and

def testImplicitOps[T:Manifest, U:Manifest] (s: Exp[Int])
  (implicit f1: Int => T, f2: Int => U): Rep[Unit] = {

  val (t, u) = (implicit_convert[Int, T](s), implicit_convert[Int, U](s))
  println("t = " + t)
  println("u = " + u)
  println("Comparing symbols: " + (t == u))

  (t, u) match {
    case (Def(a@ImplicitConvert(_)), Def(b@ImplicitConvert(_))) => {
      println("Comparing case classes: " + (a == b))
    }
  }

  val dt = findOrCreateDefinition(ImplicitConvert[Int, T](s), t.asInstanceOf[Exp[T]].pos)
  val du = findOrCreateDefinition(ImplicitConvert[Int, U](s), u.asInstanceOf[Exp[U]].pos)

  println("dt: " + dt)
  println("du: " + du)

  println("Comparing findOrCreateDefinition: " + (du == dt))
  println()
}

testImplicitOps[Long, Double](fresh[Int])(_.toLong, _.toDouble)

alexeyr commented 8 years ago

Also note that in 80fdb75, adding the manifest as an implicit argument in findDefinition does not do anything, unless you actually use it.

It's passed to infix_defines

astojanov commented 8 years ago

Yes, correct. In the following case

arrayFromSeq(Seq.empty[Exp[Int]], Seq.empty[Exp[Double]])

you will obtain:

t = Sym(0)
u = Sym(0)
Comparing symbols: true
Comparing case classes: true
dt: TP(Sym(0),ArrayFromSeq(List()))
du: TP(Sym(0),ArrayFromSeq(List()))
Comparing findOrCreateDefinition: true

which is expected in a way by the interface provided by LMS. Technically an empty array of Int is pretty much equivalent to an empty array of Double if both are immutable. This is why for example scala.collection.immutable.Nil is typeless (or has type Nothing). It is also why:

println(Seq.empty[Exp[Int]] == Seq.empty[Exp[Double]])
println(Seq.empty[Int] == Seq.empty[Double])

will result in:

true
true

The type of the array will only matter if the array is in fact mutable, and you cannot construct a mutable array using ArrayFromSeq. You should use ArrayNew instead, and that one does not go through findOrCreateDefinition, as discussed before.

Therefore, I would say, if you really need to deal with behaviour that is different than the one that LMS is expecting, a better approach would be to override findDefinition in your project, instead of LMS.

alexeyr commented 8 years ago

Technically an empty array of Int is pretty much equivalent to an empty array of Double if both are immutable. This is why for example scala.collection.immutable.Nil is typeless.

Not to the type system, they aren't! Arrays aren't covariant and you'll just get a type error if you try to use an empty Array[Int] as if it was an Array[Double]. Consider staging a function like

def f(x: Rep[Unit]): Rep[Array[Double]] = {
  // imagine lots of code which ends up calling array_obj_fromseq(Seq.empty[Int]) somewhere inside
  array_obj_fromseq(Seq.empty[Double])
}

The generated class will extend Unit => Array[Double] but apply will return Array[Int] and won't typecheck (actually it will return Array[Nothing] due to a bug in codegen, but that's a different issue).

For a non-empty array example: array_obj_fromseq(Seq(unit(1))) and array_obj_fromseq(Seq(unit(1.0))).
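
The difference between the two kinds of "empty" can be seen in plain Scala on the JVM, independent of LMS. The snippet is only meant as a small illustration of the invariance argument.

```scala
object ArrayInvarianceDemo {
  def main(args: Array[String]): Unit = {
    val ints    = Array.empty[Int]
    val doubles = Array.empty[Double]

    // On the JVM these are int[] and double[], two distinct runtime classes.
    println(ints.getClass.getName)    // [I
    println(doubles.getClass.getName) // [D

    // Immutable Seq values compare element-wise, so the two empties are equal.
    println(Seq.empty[Int] == Seq.empty[Double]) // true

    // Rejected by the compiler, because Array is invariant:
    // val wrong: Array[Double] = ints
  }
}
```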

astojanov commented 8 years ago

I agree on this. Arrays are not covariant, and ideally if one wants to have an identical behaviour as the JVM, the representation of Arrays should deal with reifiable types.

However, since the staged version of arrays, namely ArrayOps, deals with generics, it must deal with non-reifiable types, and thus its implementation behaves similarly to some other containers (List or Seq).

I cannot reproduce your latest example, but I do agree on the previous example concerning ImplicitConvert. Moving the manifest inside the case class makes sense, since you want to know what type you convert to.

For the non-empty array example, this is exactly what happens, because it is expected. In Scala:

scala> Seq[Double](1.0) == Seq[Int](1)
res0: Boolean = true

alexeyr commented 8 years ago

For the non-empty array example, this is exactly what happens, because it is expected.

Yes, it's expected currently. But again this gives you ill-typed generated code from well-typed LMS code as follows:

val arrDouble: Rep[Array[Double]] = array_obj_fromseq(Seq(unit(1.0)))
val arrInt: Rep[Array[Int]] = array_obj_fromseq(Seq(unit(1)))
val someOtherArray: Rep[Array[A]] = // doesn't matter
val i: Rep[Int] = arrInt(0)
val j: Rep[Int] = someOtherArray(i)

(I've given the types explicitly, but that's how Scala would infer them as well). In generated code you should now get something like

val x1 = Array(1.0) // used for both arrDouble and arrInt
val x2 = x1(0) // should be Int, but actually Double
val x3 = ...
val x4 = x3(x2) // oops

Though this pull request doesn't fully solve this case. It probably is necessary (and may be sufficient) to move the Manifest inside Const as well.

astojanov commented 8 years ago

I finally see what you mean. And I agree, this needs to be fixed. I think the fix should not be limited only to findDefinition but should also cover other places, such as:

Const(1) == Const(1.0)

(which is, in a way, also expected behaviour, because 1 == 1.0 on the JVM). I am not sure what the right way is or how LMS should treat this, but for my use case I would prefer not to have those two as equal objects.
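
For reference, the constant comparison can be reproduced with stand-in classes in plain Scala 2. UntypedConst, TypedConst and the typed helper are hypothetical and do not reproduce LMS's Const. The second variant is just one possible way to tell the two constants apart, by carrying the manifest in the first parameter list.

```scala
// Hypothetical stand-ins; LMS's own Const is not reproduced here.
case class UntypedConst[T](x: T)
case class TypedConst[T](x: T, tp: Manifest[T])

object ConstEqualityDemo {
  // Helper that captures the manifest from the static type of the literal.
  def typed[T](x: T)(implicit tp: Manifest[T]): TypedConst[T] = TypedConst(x, tp)

  def main(args: Array[String]): Unit = {
    // Field comparison uses Scala's numeric-aware ==, so 1 and 1.0 match.
    println(UntypedConst(1) == UntypedConst(1.0)) // true

    // With the manifest as a regular field, the types take part in equals.
    println(typed(1) == typed(1.0))               // false
    println(typed(1) == typed(1))                 // true
  }
}
```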

alexeyr commented 8 years ago

Yes, it seems distinguishing them is necessary. I've also added an option for explicit types to help debug type errors in generated code (not a mixin trait to avoid problems with mixing it and ScalaNestedCodegen in incorrect order).

alexeyr commented 8 years ago

I've split this into three pull requests (this, #105 and #106) to make the review simpler.

TiarkRompf commented 8 years ago

I finally had a chance to look at all the changes and discussion. Here is what I think:

TiarkRompf commented 8 years ago

Thinking more about it, I wonder if the change to defines and findDefinition is really necessary. It seems to make a difference only for cases of a particular pattern (like List.empty). Maybe it would make more sense to handle those cases explicitly by including the Typ instance as a proper parameter in the Def (i.e. use case class ListEmpty[A](tp: Typ[A]))?

alexeyr commented 8 years ago

That's what I did originally, but the problem there is that 1) you need to also consider places where you have List/Seq/Option[Rep[A]] or Rep[A]*; 2) it's easy to miss when new defs are added. Or, say,

ArrayFromSeq (reflectMutable is commented out there)

Now you would need to be careful to add/remove the Typ when making such changes.

On the contrary, changes to defines/findDefinition only have to be done once.

alexeyr commented 8 years ago

For defines and findDefinition, I agree that these should also check the types. But I would prefer to pass the manifest as an explicit parameter, instead of as an implicit. The reason is that sometimes we have to use findDefinition internally with type Any (for lack of a precise static type), and we have to guard against passing in Manifest[Any] by accident.

That shouldn't be a problem in 1.0.x, because there is no implicit Typ[Any], I think? It certainly can still be made explicit if preferred.
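
The Manifest[Any] concern quoted above can be illustrated in plain Scala 2. typeOfImplicit and typeOfExplicit are hypothetical helpers, not LMS functions. They only show which manifest ends up being used when the static type at the call site is Any.

```scala
object ManifestAnyDemo {
  // Implicit variant: the compiler picks the manifest from the static type.
  def typeOfImplicit[T](x: T)(implicit m: Manifest[T]): Manifest[T] = m

  // Explicit variant: the caller has to state the type deliberately.
  def typeOfExplicit[T](x: T, m: Manifest[T]): Manifest[T] = m

  def main(args: Array[String]): Unit = {
    // Static type Any, as in internal code that lacks a precise type.
    val untyped: Any = 1

    // Compiles silently and carries Manifest[Any], although the value is an Int.
    println(typeOfImplicit(untyped) == manifest[Any])          // true

    // The explicit form makes the chosen type visible at the call site.
    println(typeOfExplicit(1, manifest[Int]) == manifest[Int]) // true
  }
}
```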

alexeyr commented 8 years ago

I agree that combining lhs and mhs is an improvement, but I'm afraid that it might break things in Delite or in the fusion branch that is yet to be merged. So I would hold off on that for the moment (but we can have it as a separate PR).

Moved to #110, though it will need to be updated after a decision on defines/findDefinition.