The first time I tried writing a 3D convolution where I MemReduce each slice of the cube into a 2D result, I wrote it this way:
val lineout = SRAM[T](coltile/colstride)
Foreach(input.rows by 1){row =>
  MemReduce(lineout)(3 by 1 par 3){page =>
    val lineout_local = SRAM[T](coltile/colstride)
    val lb = LineBuffer[T](filter.rows, coltile)
    val sr = RegFile[T](filter.rows, filter.cols)
    lb load input(page, row, 0::input.cols)
    Foreach(input.cols by 1){j =>
      sr(i,*) <<= lb(i,j)
      lineout(j) = Reduce(Reg[T](0.to[T]))(3 by 1, 3 by 1){(ii,jj) =>
        sr(ii,jj) * filter(page,ii,jj)
      }{_+_}
    }
  }{_+_}
  output(row, 0::output.cols) store lineout
}
I wanted to compute one row at a time and store that row back. However, the LCA (least common ancestor) of the accesses to lb becomes the MemReduce rather than the row counter. This means the line buffer no longer loads one row while we reduce another. The only correct way to write this is either to tile it and transpose the MemReduce with the Foreach that counts rows, which takes more SRAM and probably does unnecessary reloading at the boundaries between tiles, or to metaprogram it. Either way, it took a while of staring at waveforms to figure out why it didn't work, so I think it would be impossible for an ordinary user to figure out. Any ideas for detecting this kind of thing and giving an error?
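To make the question concrete, here is a rough sketch of the kind of check I have in mind. It is plain Scala and a toy model, not the actual compiler IR: `Ctrl`, `LcaCheck`, `lca`, and `check` are all hypothetical names, and it assumes the compiler can tell which controller the user expects to advance the line buffer (here, the Foreach over rows). It finds the least common ancestor of the controller that loads into the line buffer and the controller that reads it, and reports a message when that ancestor is not the expected one.

```scala
// Toy model of a controller hierarchy (hypothetical, not Spatial's real IR).
// Assumes sibling controllers have distinct names, since equality is structural.
final case class Ctrl(name: String, parent: Option[Ctrl]) {
  // This controller followed by its ancestors, nearest first.
  def ancestors: List[Ctrl] = this :: parent.map(_.ancestors).getOrElse(Nil)
}

object LcaCheck {
  // Least common ancestor of two controllers, if they share a tree.
  def lca(a: Ctrl, b: Ctrl): Option[Ctrl] = {
    val bs = b.ancestors.toSet
    a.ancestors.find(bs.contains)
  }

  // Returns a message when the line buffer's writer and reader meet at a
  // controller other than the one expected to step through rows.
  def check(writer: Ctrl, reader: Ctrl, rowCtrl: Ctrl): Option[String] =
    lca(writer, reader).filter(_ != rowCtrl).map { c =>
      s"LineBuffer load and read meet at ${c.name}, expected ${rowCtrl.name}"
    }
}

object Demo extends App {
  // Hierarchy of the code above: rows > pages > {lb load, cols loop}.
  val rows  = Ctrl("Foreach(rows)", None)
  val pages = Ctrl("MemReduce(pages)", Some(rows))
  val load  = Ctrl("lb load", Some(pages))
  val cols  = Ctrl("Foreach(cols)", Some(pages))

  // Print the message if the check fires.
  LcaCheck.check(load, cols, rows).foreach(println)
}
```

For the hierarchy in my example the load and the column loop both sit directly under the MemReduce, so the check fires. Whether the real compiler can infer the "expected" row controller automatically is the part I am unsure about.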