uncomplicate / neanderthal

Fast Clojure Matrix Library
http://neanderthal.uncomplicate.org
Eclipse Public License 1.0

Support pointwise comparison, sign, max #32

Closed: whilo closed this issue 7 years ago

whilo commented 7 years ago

It would be beneficial to be able to set all non-positive entries of a vector to zero, e.g. for ReLU activations in neural networks. This can be achieved by applying max(0, x) pointwise. Similarly, pointwise sign, <, >, and = would prove handy if they returned 1.0 for true and 0.0 otherwise. In R, Python, Matlab, and similar environments, this is the usual way to select values from variables in a vectorized fashion.
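For reference, here is a minimal sketch of the requested semantics on plain Clojure vectors (not Neanderthal structures); the relu and gt helpers are hypothetical names used only for illustration:

(defn relu [v]
  ;; max(0, x) applied to every entry
  (mapv #(max 0.0 %) v))

(defn gt [v w]
  ;; 1.0 where an entry of v is greater than the corresponding entry of w, else 0.0
  (mapv #(if (> %1 %2) 1.0 0.0) v w))

(relu [-1.5 0.0 2.5])    ;; => [0.0 0.0 2.5]
(gt [1.0 2.0] [2.0 1.0]) ;; => [0.0 1.0]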

blueberry commented 7 years ago

I have added pointwise min and max functions.

Other cases:

  1. <, >, = not added. Many reasons, including: too fine-grained (especially for the GPU), not very suitable for floating-point operations (=), and not very useful on their own. Write coarse-grained kernels for the NN operations that you need instead of enqueuing multiple trivial kernels.
  2. (max 0 x) not added. Not vectorizable and too specific.

Basically, all of these operations can already be described as a combination of a few existing functions. For example, signum can be computed as (div x (abs x)) for non-zero entries, but I would not recommend that anyway; you have to write specialized GPU kernels for the deep-learning operations you care about if you want good performance.
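As a hedged sketch of that composition approach, assuming the newly added pointwise fmax and the existing elementwise div and abs live in uncomplicate.neanderthal.vect-math (the exact namespace and names may differ by version):

(require '[uncomplicate.neanderthal.native :refer [dv]]
         '[uncomplicate.neanderthal.vect-math :as vm])

(def x (dv [-2.0 0.5 3.0]))

;; ReLU-style clamp: pointwise max against a zero vector of the same dimension.
(vm/fmax x (dv 3))     ;; expected entries: 0.0 0.5 3.0

;; signum as a composition of existing functions; a zero entry divides by zero,
;; so this only behaves well for non-zero entries.
(vm/div x (vm/abs x))  ;; expected entries: -1.0 1.0 1.0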

whilo commented 7 years ago

Thanks a lot!

<, >, = are not just for deep learning; they are the standard way to compute index sets over arbitrary tensors. I have used them for very basic things in every numeric environment where vectorization is the norm (Matlab, R, Python), e.g. to sum all elements larger than some threshold. Say you want to select all non-zero values with some custom epsilon precision for your algorithm; you would do

(let [indexes  (vm/> v eps)       ;; array with [0.0 1.0 ... 0.0]
      non-zero (vm/mul v indexes)]
  ...)

I can do (defn < [x y] (max (signum (- y x)) 0)) though.
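Here is a hedged sketch of that eps-mask idea using only the fmap fallback (Fluokitten's fmap works on Neanderthal vectors) plus the elementwise mul from vect-math; the vm/> above is a proposed function, not an existing one:

(require '[uncomplicate.neanderthal.native :refer [dv]]
         '[uncomplicate.neanderthal.vect-math :as vm]
         '[uncomplicate.fluokitten.core :refer [fmap]])

(let [v        (dv [0.0 1e-9 0.3 2.0])
      eps      1e-6
      indexes  (fmap (fn ^double [^double x] (if (< eps x) 1.0 0.0)) v)
      non-zero (vm/mul v indexes)]
  non-zero) ;; entries at or below eps are zeroed out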

I get your point about custom deep-learning kernels; you are right that activation functions are probably better implemented in one kernel. On the other hand, staying close to device memory (instead of using fmap) should by far be the most important step performance-wise, as the matrix multiplication dominates everything else anyway. I think it is helpful to educate people along a path towards low-level optimization by letting them improve step by step before they start writing ClojureCL. Attempts are often still exploratory and not yet worth that effort, but accessing GPU memory can kill even the exploratory feedback loop.

blueberry commented 7 years ago

That is not even low-level optimization, but the basics of programming massively parallel processors. Matrix multiplication dominates only when you do not kill performance by calling trivial kernels many times.

whilo commented 7 years ago

I would need the sign function though (signum: 1 for x > 0, 0 for x = 0, and -1 for x < 0). It has problems similar to = around 0, so I am not sure whether it is better to implement <, >, = or signum. But if you think I should stick with fmap for now, then I can live with that. I just wanted to mention that these will pop up often.
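For now, sticking with fmap, here is a minimal sketch of pointwise signum over a Neanderthal vector (Math/signum already returns 1.0, 0.0, or -1.0):

(require '[uncomplicate.neanderthal.native :refer [dv]]
         '[uncomplicate.fluokitten.core :refer [fmap]])

(fmap (fn ^double [^double x] (Math/signum x)) (dv [-2.0 0.0 3.5]))
;; expected entries: -1.0 0.0 1.0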