mratsim / laser

The HPC toolbox: fused matrix multiplication, convolution, data-parallel strided tensor primitives, OpenMP facilities, SIMD, JIT Assembler, CPU detection, state-of-the-art vectorized BLAS for floats and integers
Apache License 2.0

Lux refactor v3 - Frontend #34

Closed mratsim closed 5 years ago

mratsim commented 5 years ago

This is the third overhaul of the Lux AST (and hopefully the last).

History:

V1, introduced in https://github.com/numforge/laser/commit/a88a8581f0e2a72332efbfa8b6d60d9a33f21ea2, could only do elementwise operations, with a proof-of-concept vectorization for one selected SIMD architecture (SSE, AVX, ...).

V2, introduced in https://github.com/numforge/laser/pull/29, split Lux into a high-level DSL, a frontend compiler that symbolically executes that DSL, and a backend compiler that receives the produced trace, runs lowering passes (to build loops from the tensors' iteration domains) and then generates code.

The implementation highlighted a couple of issues in the design:

Composability was bad because both the inputs A: InTensor, B: InTensor and their elements A[i,j] were of the same type, which means A + B and A[i,j] + B[i,j] would use the same proc instead of being distinguished by proc overloading.

The code generation part was tricky because of the lack of a statement list: for example, the AffineFor node had to return the nested lval symbol. This would have caused issues when implementing the loop fusion pass.

User ergonomics were bad because various initialization procs were needed for domains, tensors and parameters.

Now instead of the following

proc foobar(a: LuxNode, b, c: LuxNode): LuxNode =
  var i, j: LuxNode 
  newLuxIterDomain(i, 0, a.shape(0))    
  newLuxIterDomain(j, 0, a.shape(1))    

  var bar: LuxNode
  newLuxMutTensor(bar)

  bar[i, j] = a[i, j] + b[i, j] + c[i, j]

  result = bar

We can use this

proc foobar(a: LuxNode, b, c: LuxNode): LuxNode =
  var i, j: LuxNode 
  var bar: LuxNode

  bar[i, j] = a[i, j] + b[i, j] + c[i, j]

  result = bar

The developer ergonomics are significantly improved, with variant types now working like in Nim or Lisp: almost everything is a seq[LuxNode]. There is no need to remember single-use field names, AST traversal is more generic, and the AST is easier to extend because there is no need to think of new field names to avoid collisions. Also, since it works like the Nim AST, it's familiar to Nim devs (or would be helpful to learn Nim macros).
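To illustrate the "almost everything is a seq[LuxNode]" idea, here is a minimal sketch of a Nim-style variant AST. The node kinds, the field names (intVal, name, children) and the helper procs (lit, sym, countNodes) are hypothetical and chosen only for illustration; they are not taken from the Lux sources.

```nim
type
  LuxNodeKind = enum
    # Hypothetical node kinds, for illustration only
    IntLit, Sym, Add, Mul, StatementList

  LuxNode = ref object
    case kind: LuxNodeKind
    of IntLit:
      intVal: int              # leaf: integer literal
    of Sym:
      name: string             # leaf: named symbol
    else:
      children: seq[LuxNode]   # every interior node shares this one field

proc lit(v: int): LuxNode =
  ## Build an integer literal leaf
  LuxNode(kind: IntLit, intVal: v)

proc sym(name: string): LuxNode =
  ## Build a symbol leaf
  LuxNode(kind: Sym, name: name)

proc `+`(a, b: LuxNode): LuxNode =
  ## Build an Add node holding both operands as children
  LuxNode(kind: Add, children: @[a, b])

proc `*`(a, b: LuxNode): LuxNode =
  ## Build a Mul node holding both operands as children
  LuxNode(kind: Mul, children: @[a, b])

proc countNodes(n: LuxNode): int =
  ## Generic traversal: no per-kind field names are needed,
  ## only the distinction between leaves and interior nodes.
  result = 1
  if n.kind notin {IntLit, Sym}:
    for child in n.children:
      result += countNodes(child)
```

Because every interior node stores its operands in the same children field, adding a new kind to the enum does not require touching countNodes or inventing a new field name.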

InTensors, MutTensors, LValTensors, MutFloats/LValFloats, MutInt and LValInt have all been replaced by the Function type. Like before, it keeps version information in a persistent data structure, but instead of a tree structure it appends new versions to a sequence. Furthermore, in Lux AST v2 we were forced to store AffineFor and other statements in the previous_version field. With the AST changes and the introduction of a StatementList, this shouldn't be needed anymore.
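The "append new versions to a sequence" approach could look roughly like the sketch below. The Expr stand-in type, the field names (symbol, versions, text) and the procs (assign, latest) are assumptions made for this example and do not reflect the actual Function definition in Lux.

```nim
type
  Expr = ref object
    ## Hypothetical stand-in for an expression node
    text: string

  Function = ref object
    ## Hypothetical sketch: a symbol plus the history of its definitions
    symbol: string
    versions: seq[Expr]   # versions[k] is the k-th definition

proc assign(f: Function, e: Expr): int =
  ## Record a new definition by appending it, leaving older
  ## versions untouched, and return the index of the new version.
  f.versions.add e
  result = f.versions.high

proc latest(f: Function): Expr =
  ## The most recent definition (requires at least one version)
  f.versions[^1]
```

With this layout, earlier versions stay addressable by index, so a later pass can refer to "bar at version 0" without walking a chain of previous_version links.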

Lastly, users were exposed to the low-level concept of LuxNode.