mratsim / Arraymancer

A fast, ergonomic and portable tensor library in Nim with a deep learning focus for CPU, GPU and embedded devices via OpenMP, Cuda and OpenCL backends
https://mratsim.github.io/Arraymancer/
Apache License 2.0

Auto-fuse operations (alpha X + Y, alpha A * B + beta C ...) #31

Open mratsim opened 7 years ago

mratsim commented 7 years ago

Arraymancer can leverage the Nim compiler and its term-rewriting macros to detect, at compile time, operations that can be fused (for example `alpha * X + Y` or `alpha * A * B + beta * C`).
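To make "fusion" concrete, here is a minimal sketch on plain `seq[float]` rather than Arraymancer's `Tensor` type. It does not use term-rewriting macros. It only contrasts an unfused `alpha * x + y` (two loops and one temporary) with a fused single-pass kernel, which is the kind of code a rewrite rule would target. The procs `scaleTmp`, `addTmp` and `axpyFused` are hypothetical names for illustration.

```nim
proc scaleTmp(alpha: float, x: seq[float]): seq[float] =
  ## Unfused step 1 (hypothetical): allocates a temporary holding alpha * x.
  result = newSeq[float](x.len)
  for i in 0 ..< x.len:
    result[i] = alpha * x[i]

proc addTmp(x, y: seq[float]): seq[float] =
  ## Unfused step 2 (hypothetical): allocates the final result x + y.
  doAssert x.len == y.len
  result = newSeq[float](x.len)
  for i in 0 ..< x.len:
    result[i] = x[i] + y[i]

proc axpyFused(alpha: float, x, y: seq[float]): seq[float] =
  ## Fused (hypothetical): one allocation and one pass over x and y.
  doAssert x.len == y.len
  result = newSeq[float](x.len)
  for i in 0 ..< x.len:
    result[i] = alpha * x[i] + y[i]

when isMainModule:
  let x = @[1.0, 2.0, 3.0]
  let y = @[10.0, 20.0, 30.0]
  let unfused = addTmp(scaleTmp(2.0, x), y)  # two loops, one temporary
  let fused = axpyFused(2.0, x, y)           # one loop, no temporary
  doAssert unfused == fused
  echo fused  # @[12.0, 24.0, 36.0]
```

Both paths give the same values; the fused one simply avoids the intermediate buffer and the second traversal of memory.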

This is probably similar to what TensorFlow does with its XLA compiler. See https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html and the XLA overview.

A term-rewriting example is already included with fusing toTensor + reshape operations: https://github.com/mratsim/Arraymancer/blob/05e2f41651950c290ec20d71ebd679bd3f74ea75/src/arraymancer/term_rewriting.nim#L40-L45
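The idea behind the `toTensor` + `reshape` rule can be sketched with a toy type. This is an assumption-laden illustration, not the code in the linked file: `Toy`, `toToy`, `reshape` and `toToyReshaped` are hypothetical, and the sketch calls the fused proc directly instead of reproducing the term-rewriting macro that performs the substitution.

```nim
type Toy = object
  ## Hypothetical toy tensor: a flat buffer plus a shape (not Arraymancer's Tensor).
  data: seq[float]
  shape: seq[int]

proc toToy(s: seq[float]): Toy =
  ## Unfused step 1: builds an intermediate 1-D value.
  Toy(data: s, shape: @[s.len])

proc reshape(t: Toy, rows, cols: int): Toy =
  ## Unfused step 2: builds a second value with the final 2-D shape.
  doAssert rows * cols == t.data.len
  Toy(data: t.data, shape: @[rows, cols])

proc toToyReshaped(s: seq[float], rows, cols: int): Toy =
  ## Fused: sets the final shape while building, with no intermediate 1-D value.
  doAssert rows * cols == s.len
  Toy(data: s, shape: @[rows, cols])

when isMainModule:
  let s = @[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
  let twoSteps = toToy(s).reshape(2, 3)
  let oneStep = toToyReshaped(s, 2, 3)
  doAssert twoSteps.shape == oneStep.shape
  doAssert twoSteps.data == oneStep.data
```

A rewrite rule only has to recognise the `toToy(s).reshape(r, c)` pattern and replace it with `toToyReshaped(s, r, c)`; the result is identical by construction.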

mratsim commented 7 years ago

There is a Julia library called Devectorize that can perform more complex fusion operations:

```julia
@devec r = a + b + c
@devec r = sin(a) + exp(a + 1.0) .* log(c)
```

Here `.*` is Julia's element-wise product.
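Evaluated naively, each intermediate of the second expression (`sin(a)`, `a + 1.0`, `exp(...)`, `log(c)` and the product) becomes its own temporary array. Devectorizing turns it into one loop. A hand-written Nim equivalent on `seq[float]` might look like the sketch below; `devecExample` is a hypothetical name. Note that Julia's `log` is the natural logarithm, which is `ln` in Nim's `std/math`.

```nim
import std/math

proc devecExample(a, c: seq[float]): seq[float] =
  ## Single-pass equivalent of the Julia expression
  ## r = sin(a) + exp(a + 1.0) .* log(c)
  doAssert a.len == c.len
  result = newSeq[float](a.len)
  for i in 0 ..< a.len:
    # every sub-expression is evaluated per element, so no temporaries
    result[i] = sin(a[i]) + exp(a[i] + 1.0) * ln(c[i])
```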

mratsim commented 7 years ago

One option is to implement an `elementwise` template that would be used like this:

```nim
elementwise:
  let out_of_place = alpha * A + exp(B) + beta * sin(C)
  in_place += sin(A) + exp(A + 1) + C
```

Inside the block, operators, functions and scalars would automatically be applied element-wise to Tensors. This would allow a nicer syntax than explicit broadcasting for some functions, such as sigmoid.
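A much smaller version of the same idea can be sketched with an ordinary Nim template and injected symbols. This is not Arraymancer's `elementwise` implementation: `zipEval` is a hypothetical helper limited to two `seq[float]` operands, which it exposes to the expression as `x` and `y`. The expression is expanded inside a single loop, so its sub-expressions never allocate temporaries.

```nim
import std/math

template zipEval(a, b: seq[float], body: untyped): seq[float] =
  ## Hypothetical helper: evaluates `body` once per index in one loop,
  ## with `x` and `y` injected as the current elements of `a` and `b`.
  let lhs = a            # local template symbols are gensym'd (hygienic)
  let rhs = b
  doAssert lhs.len == rhs.len
  var res = newSeq[float](lhs.len)
  for i in 0 ..< lhs.len:
    let x {.inject.} = lhs[i]
    let y {.inject.} = rhs[i]
    res[i] = body
  res

when isMainModule:
  let A = @[0.0, 1.0]
  let C = @[1.0, 2.0]
  let alpha = 2.0
  # mirrors `sin(A) + exp(A + 1) + C` from the proposal, on plain seqs
  let r = zipEval(A, C, sin(x) + exp(x + 1.0) + y)
  doAssert r.len == 2
  let s = zipEval(A, C, alpha * x + y)
  echo s  # @[1.0, 4.0]
```

The real proposal goes further by rewriting whole Tensor expressions without naming the elements, but the code generated per element would be of this shape.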

Tentative broadcasting version:

```nim
proc sigmoid[T](x: Tensor[T]): Tensor[T] =
  result = 1.0f.bc |/| (1.0f.bc + exp(-x))
```

Tentative element-wise version:

```nim
proc sigmoid[T](x: Tensor[T]): Tensor[T] =
  result = newTensor[T](x.shape)
  elementwise:
    result = 1.0 / (1.0 + exp(-x))
```

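The broadcasting version would typically allocate a new tensor for each of `-x`, `exp(-x)`, the sum and the quotient. What the element-wise version is meant to compile down to is closer to the loop below, sketched on plain `seq` values; `sigmoidFused` is a hypothetical name and not part of Arraymancer.

```nim
import std/math

proc sigmoidFused[T: SomeFloat](x: seq[T]): seq[T] =
  ## Hypothetical sketch: one allocation and one pass,
  ## computing 1 / (1 + exp(-x)) element by element.
  result = newSeq[T](x.len)
  for i in 0 ..< x.len:
    result[i] = T(1) / (T(1) + exp(-x[i]))

when isMainModule:
  echo sigmoidFused(@[0.0])  # @[0.5]
```

Using `T(1)` keeps the constants in the element type, so the same proc works for `float32` and `float64` without mixing precisions.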
mratsim commented 6 years ago

Implemented in https://github.com/mratsim/Arraymancer/blob/89a72b348f09b992006ebdb35d991ba80bd4c671/src/tensor/optim_ops_fusion.nim

It still needs tests similar to https://github.com/mratsim/Arraymancer/issues/152, as well as more fusion rules for the minus operator.
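One possible shape for such a test is to check that a fused kernel agrees with the unfused two-step computation over several inputs. The sketch below does this for `alpha * x - y` on plain `seq[float]`. Both `axmyFused` and `axmyUnfused` are hypothetical; an actual Arraymancer test would instead compare an expression that the term-rewriting rules fuse against a reference built from separate operations.

```nim
import std/sequtils

proc axmyFused(alpha: float, x, y: seq[float]): seq[float] =
  ## Hypothetical fused kernel for `alpha * x - y`: one loop, no temporary.
  doAssert x.len == y.len
  result = newSeq[float](x.len)
  for i in 0 ..< x.len:
    result[i] = alpha * x[i] - y[i]

proc axmyUnfused(alpha: float, x, y: seq[float]): seq[float] =
  ## Reference: materialise alpha * x first, then subtract y.
  doAssert x.len == y.len
  let scaled = x.mapIt(alpha * it)
  result = newSeq[float](x.len)
  for i in 0 ..< x.len:
    result[i] = scaled[i] - y[i]

when isMainModule:
  let x = @[1.5, -2.0, 0.25]
  let y = @[0.5, 4.0, -1.0]
  for alpha in [0.0, 1.0, -3.0]:
    let fused = axmyFused(alpha, x, y)
    let reference = axmyUnfused(alpha, x, y)
    for i in 0 ..< x.len:
      doAssert abs(fused[i] - reference[i]) <= 1e-12
  echo "fused alpha * x - y matches the unfused reference"
```

Comparing with a tolerance rather than exact equality leaves room for a fused kernel that uses FMA instructions, which can round differently from a separate multiply and subtract.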