hikettei / cl-waffe2

[Experimental] Graph and Tensor Abstraction for Deep Learning all in Common Lisp
https://hikettei.github.io/cl-waffe2/
MIT License
ai autodiff common-lisp deep-learning-framework math matrix-operations optimizing


Programmable Deep Learning Framework


cl-waffe2

Concepts

⚠️ cl-waffe2 is still in the experimental stage. Things are subject to change, and APIs may change without warning. DO NOT USE CL-WAFFE2 IN YOUR PRODUCT.

I also have an older repository with a similar name, cl-waffe (DEPRECATED, UNSUPPORTED!). Note that cl-waffe2 is its successor and inherits all of its features.

cl-waffe2 provides fast, systematic, easy-to-optimize, customizable, device-independent abstract matrix operations and reverse-mode, tape-based automatic differentiation in Common Lisp. It also provides features for building and training neural network models, accelerated by a JIT compiler.
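To make "reverse-mode, tape-based" concrete, here is a minimal scalar sketch in plain Common Lisp. It illustrates the general technique only and is not cl-waffe2's implementation or API: cl-waffe2 works on tensors and compiles its graphs, whereas every name below (make-var, v+, v*, backprop, *tape*) is hypothetical. Each operation computes its value and pushes a closure onto a tape; backprop seeds the output gradient with 1 and replays the tape from newest to oldest, accumulating gradients via the chain rule.

```lisp
;;; Toy reverse-mode, tape-based autodiff on scalars.
;;; Hypothetical illustration only: NOT cl-waffe2's implementation or API.

(defvar *tape* '()
  "Recorded backward closures, most recent first.")

(defstruct (var (:constructor %make-var (value)))
  (value 0d0 :type double-float)
  (grad  0d0 :type double-float))

(defun make-var (x)
  (%make-var (coerce x 'double-float)))

(defun record (out backward-fn)
  "Push a closure that propagates OUT's gradient to its inputs; return OUT."
  ;; The closure reads OUT's gradient only when the tape is replayed.
  (push (lambda () (funcall backward-fn (var-grad out))) *tape*)
  out)

(defun v+ (a b)
  (let ((out (make-var (+ (var-value a) (var-value b)))))
    (record out (lambda (g)
                  ;; d(a+b)/da = d(a+b)/db = 1
                  (incf (var-grad a) g)
                  (incf (var-grad b) g)))))

(defun v* (a b)
  (let ((out (make-var (* (var-value a) (var-value b)))))
    (record out (lambda (g)
                  ;; d(a*b)/da = b, d(a*b)/db = a
                  (incf (var-grad a) (* g (var-value b)))
                  (incf (var-grad b) (* g (var-value a)))))))

(defun backprop (out)
  "Seed d(out)/d(out) = 1, replay the tape in reverse, then clear it."
  (setf (var-grad out) 1d0)
  ;; *TAPE* is newest-first, so iterating it walks the computation backwards.
  (dolist (backward-step *tape*)
    (funcall backward-step))
  (setf *tape* '()))

;; z = x*y + x at x = 3, y = 4
(let* ((x (make-var 3))
       (y (make-var 4))
       (z (v+ (v* x y) x)))
  (backprop z)
  (list (var-value z) (var-grad x) (var-grad y)))
;; => (15.0d0 5.0d0 3.0d0)
```

Here dz/dx = y + 1 = 5 and dz/dy = x = 3; note that x receives gradient from two places (the product and the sum), which is why gradients are accumulated with incf rather than assigned.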

Roughly speaking, this is a framework for graph and tensor abstraction without overhead. Every feature provided here can be extended by users, without exception and with minimal code. In fact, cl-waffe2 is designed to be the easiest framework for users to extend: there is no barrier between users and developers, and the framework imposes no restrictions other than that the development language is Common Lisp.

As of this writing, its abstraction layers have nearly reached their goals and work well enough, but backend functionality and documentation are still seriously lacking. Contributions are welcome, and I would appreciate it if anyone interested in the project contacted me: hikettei.

✨Features

🍃 Quicklook

As the simplest example, the build function traces the computation nodes backward from their endpoint (the output node) and compiles the network into a function.

Example 1. Compiling nodes

(let ((a (make-input `(A B) :A))
      (b (make-input `(A B) :B)))
  (let ((model (build (!sum (!mul a b)) :inputs `(:A :B))))
    (print model)
    ;; model is a compiled function: f(a b)
    (forward model (randn `(3 3)) (randn `(3 3)))))

;;<Compiled-Composite(allocated-p=NIL)
;;    forward     : forward(model A B) -> CPUTENSOR{FLOAT}(1 1)
;;    backward    : backward(model) -> t
;;    memory-pool : two tensor(s)
;;                   L {8.0e-6+((A B) x 4.0e-6)}MB
;;    inputs:
;;        A -> (A B)
;;        B -> (A B)
;;> 

;;{CPUTENSOR[float] :shape (1 1) -> :view (<(BROADCAST 1)> <(BROADCAST 1)>) -> :visible-shape (1 1) :named ChainTMP646587 
;;  ((1.0858848))
;;  :facet :input
;;  :requires-grad NIL
;;  :backward NIL} 
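The compiled model above returns a 1x1 tensor. Assuming !mul is element-wise and !sum reduces over every axis (which the (1 1) result shape suggests), the model computes the sum of a[i][j] * b[i][j] over all i and j. The plain Common Lisp function below (hypothetical, not part of cl-waffe2) spells out those semantics on ordinary 2D arrays and can be handy for cross-checking results:

```lisp
;;; Reference semantics of Example 1 in plain Common Lisp (no cl-waffe2):
;;; the sum over all i, j of a[i][j] * b[i][j].
(defun sum-of-elementwise-product (a b)
  "A and B are 2D arrays with identical dimensions."
  (assert (equal (array-dimensions a) (array-dimensions b)))
  (loop for i below (array-dimension a 0)
        sum (loop for j below (array-dimension a 1)
                  sum (* (aref a i j) (aref b i j)))))

(sum-of-elementwise-product #2A((1 2) (3 4)) #2A((5 6) (7 8)))
;; => 70
```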

The advantages of using Common Lisp are numerous:

Example 2. MLP Model

;; From https://github.com/hikettei/cl-waffe2/blob/master/examples/mnist/mlp.lisp
(defsequence MLP (in-features hidden-dim out-features
               &key (activation #'!relu))
         "Three Layers MLP Model"
         (LinearLayer in-features hidden-dim)
         (asnode activation)
         (LinearLayer hidden-dim hidden-dim)
         (asnode activation)
         (LinearLayer hidden-dim out-features))

(defun build-mlp-model (in-class out-class &key (hidden-size 256) (activation #'!relu) (lr 1e-3))
  (let* ((mlp       (MLP in-class hidden-size out-class :activation activation))
         ;; Build the loss as a lazy graph; it is not evaluated yet.
         (lazy-loss (criterion #'softmax-cross-entropy
                               (call mlp
                                     (make-input `(batch-size ,in-class) :X))
                               (make-input `(batch-size ,out-class) :Y)
                               :reductions (list #'!sum #'->scal)))
         ;; Compile the graph into a model that takes :X and :Y as inputs.
         (model     (build lazy-loss :inputs `(:X :Y))))
    ;; Attach an Adam optimizer to every trainable parameter.
    (mapc (hooker x (Adam x :lr lr)) (model-parameters model))
    (values model mlp)))

(defun step-train-mlp (model x y)
  (let ((act-loss (forward model x y)))
    (backward model)
    (mapc #'call-optimizer! (model-parameters model))
    (/ (tensor-vec act-loss) 100)))

(defmethod accuracy ((model MLP) x y)
  (let* ((out   (!argmax (call model x)))
         (label (!argmax y))
         (total (proceed (->scal (!sum (A=B out label))))))
    (float (/ (tensor-vec total) (nth 0 (shape out))))))

Example 3. Reshape and transform

(!reshape (make-input `(N C H W) nil) (~ N C H W -> (* N C H) W))
(%transform (ax+b `(3) 1 0)[i] -> [~ i])
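In the first line, the shape specification maps an input of shape (N C H W) to ((* N C H) W): the three leading axes are merged into one and W is kept. The arithmetic involved is sketched below as a hypothetical plain Common Lisp helper; it is not cl-waffe2's reshape machinery.

```lisp
;;; Hypothetical helper, not part of cl-waffe2: the shape arithmetic of
;;; (~ N C H W -> (* N C H) W) applied to a concrete 4D shape.
(defun flatten-leading-axes (shape)
  "Map a 4D SHAPE (N C H W) to (N*C*H W)."
  (destructuring-bind (n c h w) shape
    (let ((result (list (* n c h) w)))
      ;; A reshape must preserve the total number of elements.
      (assert (= (reduce #'* shape) (reduce #'* result)))
      result)))

(flatten-leading-axes '(2 3 4 5))
;; => (24 5)
```

The element count is unchanged (2*3*4*5 = 24*5 = 120), which is what makes this reshape valid.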

We also provide example projects here!

📈 Performance

Don't underestimate the power of lazy evaluation: nodes are first converted into fully optimized IRs (intermediate representations) before forward and backward propagation are performed.
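The benefit of deferring evaluation is that the whole graph is visible before anything runs, so it can be rewritten first. Below is a toy sketch of that idea, assuming a hypothetical IR made of nested (op a b) lists and a single optimization pass (constant folding). cl-waffe2's actual IR and passes are different and considerably more involved.

```lisp
;;; Toy illustration of optimizing a graph before evaluating it.
;;; Hypothetical IR, not cl-waffe2's: nodes are numbers, input symbols,
;;; or binary lists (op a b) where OP names a global function.

(defun fold-constants (node)
  "Recursively replace (op number number) subtrees with their value."
  (if (atom node)
      node
      (destructuring-bind (op a b) node
        (let ((a (fold-constants a))
              (b (fold-constants b)))
          (if (and (numberp a) (numberp b))
              (funcall op a b)
              (list op a b))))))

(defun evaluate (node env)
  "Evaluate NODE, looking up input symbols in the alist ENV."
  (cond ((numberp node) node)
        ((symbolp node) (cdr (assoc node env)))
        (t (destructuring-bind (op a b) node
             (funcall op (evaluate a env) (evaluate b env))))))

;; Building the graph computes nothing; it is just data.
(defparameter *graph* '(* x (+ 2 3)))

(fold-constants *graph*)                        ; => (* X 5)
(evaluate (fold-constants *graph*) '((x . 4)))  ; => 20
```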

Since cl-waffe2 is still under development, many optimization techniques remain to be implemented. Even so, with these benchmarks measured on a single core, it already performs comparably to Keras, PyTorch, and JAX.

MLP

optimizer=Adam, hidden_size=256

n_epoch   cl-waffe2   Keras      PyTorch    JAX
1         3.111s      3.662s     3.418s     4.039s
10        32.437s     31.352s    28.403s    30.801s
100       304.864s    274.854s   338.031s   275.875s

optimizer=Adam, hidden_size=512

n_epoch   cl-waffe2   Keras      PyTorch    JAX
1         6.075s      7.55s      7.29s      6.90s
10        61.703s     56.283s    51.140s    65.682s

ResNet18

(Coming Soon...)

Text Generation

(Coming Soon...)

📕 Acknowledgments