stanford-ppl / spatial-lang

Spatial: "Specify Parameterized Accelerators Through Inordinately Abstract Language"
MIT License
99 stars 12 forks source link

Banking and Dispatch Strangeness #226

Open mattfel1 opened 6 years ago

mattfel1 commented 6 years ago

What is going on here? I was looking at Stefan's lenet and there are strange things with c2_RF:

There are 25 accesses to c2_RF, each parallelized by 4. In the final IR, there are 73 duplicates of c2_RF. The last 6 duplicates, plus 3 others scattered seemingly at random through the duplicates list, each have four SRAMLoads dispatched to them, while the rest have just one SRAMLoad each (9 × 4 + 64 × 1 = 100 = 25 × 4, so every unrolled read is accounted for). My guess is given further below.

Each of the duplicates with multiple loads dispatched to it is BankedMemory(List(Banking(1,1,true)),1,false), while the ones with a single SRAMLoad are BankedMemory(List(Banking(1,8,true)),1,false), which seems completely wrong.

The write parallelization on c2_RF is 8, hence the 8 banks. The lanes of each vectorized read of c2_RF are 25 apart (i.e. addresses 0, 25, 50, 75 are part of the same access in the app).

It looks like the 9 duplicates that have 4 accesses each are the ones that happen to fall out nicely from some math involving the LCM of 8, 4, and 25, or some other relationship between these numbers?

Below is the app that exposed this

import spatial.dsl._
import org.virtualized._

/** Reproducer application for the c2_RF banking/dispatch question.
  *
  * Stages a single 5x5 convolution-style reduction inside `Accel`. The statement of interest is the
  * `c2_RF` SRAM: it is written by one `par 8` load and read 25 times inside a loop that is
  * parallelized by 4.
  *
  * NOTE(review): this is a bug reproducer, so declarations that look unused are left in place on
  * purpose; removing them could change the generated IR.
  */
object hotfix extends SpatialApp {

  // override val target = targets.AWS_F1
  // Object-level element type used by `main` when loading CSVs.
  type T = FixPt[TRUE,_5,_11] // Signed

  // Leading dimension of i0_DRAM / tmp5_DRAM and of the reshaped i0 host array.
  val BATCH_SIZE = 17

  /** Copies the host arrays into DRAM, runs the accelerator block, and returns the contents of the
    * 64-element `debug` DRAM.
    *
    * The type parameter `T` declared here shadows the object-level `type T` alias inside this method.
    *
    * NOTE(review): only `c2` influences the returned values. The other arrays are copied into DRAM
    * but those DRAMs are never read inside `Accel`; `tmp1_SRAM` is filled with synthetic data
    * instead of being loaded from `i0_DRAM`.
    *
    * @param i0 flat input array, reshaped to (BATCH_SIZE,1,28,28) and zero-padded to an inner dim of 32
    * @param c0 flat array, reshaped to (20,1,25) and zero-padded to an inner dim of 32
    * @param c1 flat array copied to a 32-element DRAM
    * @param c2 flat array, reshaped to (50,500) and zero-padded to (50,512); row 0 feeds `c2_RF`
    * @param c3 flat array copied to a 64-element DRAM
    * @param c4 flat array, reshaped to (500,800) and copied to DRAM
    * @param c5 flat array copied to a 512-element DRAM
    * @param c6 flat array, reshaped to (10,500) and zero-padded to (10,512)
    * @param c7 flat array copied to a 32-element DRAM
    * @return the 64 values read back from `debug` (the 8x8 reduction result, flattened row-major)
    */
  @virtualize
  def lenet_0e[T:Type:Num](
    i0: Array[T],
    c0: Array[T],
    c1: Array[T],
    c2: Array[T],
    c3: Array[T],
    c4: Array[T],
    c5: Array[T],
    c6: Array[T],
    c7: Array[T]
  ) : Array[T] = {

    // DRAM allocations. Several inner dimensions are larger than the host data (25->32, 28->32,
    // 500->512); the corresponding host arrays are zero-padded below before being copied.
    val c0_DRAM = DRAM[T](20,1,32)  // TODO: Eventually will be 4D again, for now 2D loads
    val i0_DRAM = DRAM[T](BATCH_SIZE,1,28,32)
    val c1_DRAM = DRAM[T](32)
    val c2_DRAM = DRAM[T](50,512)
    val c3_DRAM = DRAM[T](64)
    val c4_DRAM = DRAM[T](500,800)

    val c5_DRAM = DRAM[T](512)
    val c6_DRAM = DRAM[T](10,512)

    val c7_DRAM = DRAM[T](32)
    // NOTE(review): tmp5_DRAM is allocated but not referenced anywhere else in this method.
    val tmp5_DRAM = DRAM[T](BATCH_SIZE,32)

    // Pad c0's innermost dimension from 25 to 32 with zeros.
    val c0_reshaped = c0.reshape(20,1,25)
    val c0_new = (0::20, 0::1, 0::32){(i,j,k) => if (k < 25) c0_reshaped(i,j,k) else 0 };
    setMem(c0_DRAM, c0_new)

    // Pad i0's innermost dimension from 28 to 32 with zeros.
    val i0_reshaped = i0.reshape(BATCH_SIZE,1,28,28)
    val i0_new = (0::BATCH_SIZE, 0::1, 0::28, 0::32){(i,j,k,l) => if (l < 28) i0_reshaped(i,j,k,l) else 0 };
    setMem(i0_DRAM, i0_new)

    setMem(c1_DRAM, c1)

    // Pad c2's innermost dimension from 500 to 512 with zeros. Row 0 of this DRAM is what c2_RF loads.
    val c2_reshaped = c2.reshape(50,500)
    val c2_new = (0::50, 0::512){(i,j) => if (j < 500) c2_reshaped(i,j) else 0 };
    setMem(c2_DRAM, c2_new)

    setMem(c3_DRAM, c3)

    setMem(c4_DRAM, c4.reshape(500,800))

    setMem(c5_DRAM, c5)

    // Pad c6's innermost dimension from 500 to 512 with zeros.
    val c6_reshaped = c6.reshape(10,500)
    val c6_new = (0::10, 0::512){(i,j) => if (j < 500) c6_reshaped(i,j) else 0 };
    setMem(c6_DRAM, c6_new)

    setMem(c7_DRAM, c7)

    // Output buffer: receives the flattened 8x8 result at the end of Accel.
    val debug = DRAM[T](64)

    Accel {

        // Conv2D
        // Input feature map stand-in: filled with the synthetic value (k + i + j), not loaded from DRAM.
        val tmp1_SRAM = SRAM[T](20,12,12)
        Foreach(20 by 1, 12 by 1, 12 by 1){(i,j,k) => tmp1_SRAM(i,j,k) = (k+ i + j).to[T]}
        // NOTE(review): tmp2_SRAM, nr, nc, kr and kc are declared but not used below; only
        // or, oc and d are referenced.
        val tmp2_SRAM = SRAM[T](50,4,4)
        val nr = 12
        val nc = 12
        val kr = 5
        val kc = 5
        val or = 8
        val oc = 8
        val d = 20
        // Accumulator for the MemReduce below (or x oc = 8 x 8).
        val tmp2_SRAM_conv = SRAM[T](or, oc)
        // The memory under discussion: 512 elements, written once by the par-8 load on the next
        // line and read 25 times per iteration of the inner Foreach.
        val c2_RF = SRAM[T](512)
        c2_RF load c2_DRAM(0, 0::512 par 8)                                        // MATT SEE HERE
        // Outer loop over d = 20 values of inD_i, parallelized by 4. Each c2_RF read address is
        // inD_i*25 + <constant>, so the 4 lanes of one read are 25 addresses apart.
        MemReduce(tmp2_SRAM_conv)(d by 1 par 4) { inD_i => // in channels
          // Per-iteration partial result that MemReduce combines into tmp2_SRAM_conv.
          val result = SRAM[T](or, oc)
          Foreach(0 until or, 0 until oc par 1) { (r,c) =>

            // 25 products, one per (row offset, column offset) pair of a 5x5 window:
            //   tmp1_SRAM(inD_i, r + dr, c + dc) * c2_RF(inD_i*25 + dr*5 + dc), dr, dc in 0..4.
            // With r, c in 0..7 the tmp1_SRAM indices stay within 0..11; with inD_i in 0..19 the
            // c2_RF index stays within 0..499.
            val prod00 = tmp1_SRAM(inD_i, r.to[Index]+0.to[Index],c.to[Index]+0.to[Index]) * c2_RF(inD_i.to[Index]*25 + 0*5+0)
            val prod01 = tmp1_SRAM(inD_i, r.to[Index]+0.to[Index],c.to[Index]+1.to[Index]) * c2_RF(inD_i.to[Index]*25 + 0*5+1)
            val prod02 = tmp1_SRAM(inD_i, r.to[Index]+0.to[Index],c.to[Index]+2.to[Index]) * c2_RF(inD_i.to[Index]*25 + 0*5+2)
            val prod03 = tmp1_SRAM(inD_i, r.to[Index]+0.to[Index],c.to[Index]+3.to[Index]) * c2_RF(inD_i.to[Index]*25 + 0*5+3)
            val prod04 = tmp1_SRAM(inD_i, r.to[Index]+0.to[Index],c.to[Index]+4.to[Index]) * c2_RF(inD_i.to[Index]*25 + 0*5+4)
            val prod05 = tmp1_SRAM(inD_i, r.to[Index]+1.to[Index],c.to[Index]+0.to[Index]) * c2_RF(inD_i.to[Index]*25 + 1*5+0)
            val prod06 = tmp1_SRAM(inD_i, r.to[Index]+1.to[Index],c.to[Index]+1.to[Index]) * c2_RF(inD_i.to[Index]*25 + 1*5+1)
            val prod07 = tmp1_SRAM(inD_i, r.to[Index]+1.to[Index],c.to[Index]+2.to[Index]) * c2_RF(inD_i.to[Index]*25 + 1*5+2)
            val prod08 = tmp1_SRAM(inD_i, r.to[Index]+1.to[Index],c.to[Index]+3.to[Index]) * c2_RF(inD_i.to[Index]*25 + 1*5+3)
            val prod09 = tmp1_SRAM(inD_i, r.to[Index]+1.to[Index],c.to[Index]+4.to[Index]) * c2_RF(inD_i.to[Index]*25 + 1*5+4)
            val prod10 = tmp1_SRAM(inD_i, r.to[Index]+2.to[Index],c.to[Index]+0.to[Index]) * c2_RF(inD_i.to[Index]*25 + 2*5+0)
            val prod11 = tmp1_SRAM(inD_i, r.to[Index]+2.to[Index],c.to[Index]+1.to[Index]) * c2_RF(inD_i.to[Index]*25 + 2*5+1)
            val prod12 = tmp1_SRAM(inD_i, r.to[Index]+2.to[Index],c.to[Index]+2.to[Index]) * c2_RF(inD_i.to[Index]*25 + 2*5+2)
            val prod13 = tmp1_SRAM(inD_i, r.to[Index]+2.to[Index],c.to[Index]+3.to[Index]) * c2_RF(inD_i.to[Index]*25 + 2*5+3)
            val prod14 = tmp1_SRAM(inD_i, r.to[Index]+2.to[Index],c.to[Index]+4.to[Index]) * c2_RF(inD_i.to[Index]*25 + 2*5+4)
            val prod15 = tmp1_SRAM(inD_i, r.to[Index]+3.to[Index],c.to[Index]+0.to[Index]) * c2_RF(inD_i.to[Index]*25 + 3*5+0)
            val prod16 = tmp1_SRAM(inD_i, r.to[Index]+3.to[Index],c.to[Index]+1.to[Index]) * c2_RF(inD_i.to[Index]*25 + 3*5+1)
            val prod17 = tmp1_SRAM(inD_i, r.to[Index]+3.to[Index],c.to[Index]+2.to[Index]) * c2_RF(inD_i.to[Index]*25 + 3*5+2)
            val prod18 = tmp1_SRAM(inD_i, r.to[Index]+3.to[Index],c.to[Index]+3.to[Index]) * c2_RF(inD_i.to[Index]*25 + 3*5+3)
            val prod19 = tmp1_SRAM(inD_i, r.to[Index]+3.to[Index],c.to[Index]+4.to[Index]) * c2_RF(inD_i.to[Index]*25 + 3*5+4)
            val prod20 = tmp1_SRAM(inD_i, r.to[Index]+4.to[Index],c.to[Index]+0.to[Index]) * c2_RF(inD_i.to[Index]*25 + 4*5+0)
            val prod21 = tmp1_SRAM(inD_i, r.to[Index]+4.to[Index],c.to[Index]+1.to[Index]) * c2_RF(inD_i.to[Index]*25 + 4*5+1)
            val prod22 = tmp1_SRAM(inD_i, r.to[Index]+4.to[Index],c.to[Index]+2.to[Index]) * c2_RF(inD_i.to[Index]*25 + 4*5+2)
            val prod23 = tmp1_SRAM(inD_i, r.to[Index]+4.to[Index],c.to[Index]+3.to[Index]) * c2_RF(inD_i.to[Index]*25 + 4*5+3)
            val prod24 = tmp1_SRAM(inD_i, r.to[Index]+4.to[Index],c.to[Index]+4.to[Index]) * c2_RF(inD_i.to[Index]*25 + 4*5+4)

            // Hand-written pairwise adder tree over the 25 products. At each level the odd
            // leftover term is passed through unchanged to the next level.
            // Level 0: 25 products -> 13 terms.
            val tree_level_0_00 = prod00 + prod01
            val tree_level_0_01 = prod02 + prod03
            val tree_level_0_02 = prod04 + prod05
            val tree_level_0_03 = prod06 + prod07
            val tree_level_0_04 = prod08 + prod09
            val tree_level_0_05 = prod10 + prod11
            val tree_level_0_06 = prod12 + prod13
            val tree_level_0_07 = prod14 + prod15
            val tree_level_0_08 = prod16 + prod17
            val tree_level_0_09 = prod18 + prod19
            val tree_level_0_10 = prod20 + prod21
            val tree_level_0_11 = prod22 + prod23
            val tree_level_0_12 = prod24

            // Level 1: 13 terms -> 7 terms.
            val tree_level_1_00 = tree_level_0_00 + tree_level_0_01
            val tree_level_1_01 = tree_level_0_02 + tree_level_0_03
            val tree_level_1_02 = tree_level_0_04 + tree_level_0_05
            val tree_level_1_03 = tree_level_0_06 + tree_level_0_07
            val tree_level_1_04 = tree_level_0_08 + tree_level_0_09
            val tree_level_1_05 = tree_level_0_10 + tree_level_0_11
            val tree_level_1_06 = tree_level_0_12

            // Level 2: 7 terms -> 4 terms.
            val tree_level_2_00 = tree_level_1_00 + tree_level_1_01
            val tree_level_2_01 = tree_level_1_02 + tree_level_1_03
            val tree_level_2_02 = tree_level_1_04 + tree_level_1_05
            val tree_level_2_03 = tree_level_1_06

            // Level 3: 4 terms -> 2 terms; the final add happens in the store below.
            val tree_level_3_00 = tree_level_2_00 + tree_level_2_01
            val tree_level_3_01 = tree_level_2_02 + tree_level_2_03

            result(r, c) = tree_level_3_00 + tree_level_3_01
          }
          result
        }{_+_} // Reduce across in channels

        // Flatten the 8x8 accumulator row-major (index i*oc + j) into a 64-element SRAM so it can
        // be stored to the 1-D debug DRAM.
        val flat = SRAM[T](or*oc)
        Foreach(or by 1, oc by 1){(i,j) => flat(i*oc+j) = tmp2_SRAM_conv(i,j)}

        debug store flat

      }

    getMem(debug)
  }

  /** Entry point: loads the nine input arrays from CSV, runs `lenet_0e`, and prints the result.
    *
    * NOTE(review): the CSV paths are absolute paths into one user's home directory, so the app
    * will only run as-is on a machine where those files exist. The gold comparison at the bottom
    * is commented out, so no PASS/FAIL is reported.
    */
  @virtualize
  def main() {
    // Each file is read as a 1-D array using "\n" as the second argument
    // (presumably the element delimiter — confirm against loadCSV1D).
    val i0 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/data_in2.csv", "\n")
    val c0 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c0.csv", "\n") // conv1/Variable
    val c1 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c1.csv", "\n") // conv1/Variable_1
    val c2 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c2.csv", "\n") // conv2/Variable
    val c3 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c3.csv", "\n") // conv2/Variable_1
    val c4 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c4.csv", "\n") // fc1/Variable
    val c5 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c5.csv", "\n") // fc1/Variable_1
    val c6 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c6.csv", "\n") // fc2/Variable
    val c7 = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/c7.csv", "\n") // fc2/Variable_1
    val output = lenet_0e(i0, c0, c1, c2, c3, c4, c5, c6, c7)
    printArray(output, "got: ")
    // Disabled gold check, kept for reference.
    // val output_no_extra = Array.tabulate(170){i => output(i/10, i%10)}
    // printArray(output_no_extra, "output")
    // val gold = loadCSV1D[T]("/home/shadjis/spatial/DEVELOP_spatial-lang/csv_lenetDNNW/data_out.csv", "\n")
    // printArray(gold, "gold")
    // val margin = 1.882.to[T]
    // val cksum = gold.zip(output_no_extra){(a,b) => abs(a-b) < margin}.reduce{_&&_}
    // println("PASS: " + cksum)
  }
}