janet-lang / spork

Various Janet utility modules - the official "Contrib" library.
MIT License
117 stars · 35 forks

Feature: Basic matrix operations until QR and SVD algorithms #172

Closed by hyiltiz 7 months ago

hyiltiz commented 7 months ago
sogaiu commented 7 months ago

I was going to mention using ; (splice) vs apply, for example, instead of:

(defn dot [v1 v2]
  (apply + (map * v1 v2)))

expressing as:

(defn dot [v1 v2]
  (+ ;(map * v1 v2)))

but when I compared the speed difference, I got the sense that apply was slightly faster (I added some newlines for readability):

$ janet
Janet 1.33.0-23b0fe9f linux/x64/gcc - '(doc)' for help

repl:1:> (do (use spork/test) :done)
:done

repl:2:> (defn dot1 [v1 v2] (apply + (map * v1 v2)))
<function dot1>

repl:3:> (defn dot2 [v1 v2] (+ ;(map * v1 v2)))
<function dot2>

repl:4:> (timeit-loop [:timeout 10] "dot1" (dot1 [2 3 8 9] [1 0 -2 7]))
dot1 10.000s, 0.9481µs/body
nil

repl:5:> (timeit-loop [:timeout 10] "dot2" (dot2 [2 3 8 9] [1 0 -2 7]))
dot2 10.000s, 1.011µs/body
nil

That got me thinking though... maybe using loop with an accumulator could be faster:

repl:6:> (defn dot3 [v1 v2] (var t 0) (loop [i :in v1 j :in v2] (+= t (* i j))) t)
<function dot3>

repl:7:> (timeit-loop [:timeout 10] "dot3" (dot3 [2 3 8 9] [1 0 -2 7]))
dot3 10.000s, 0.4623µs/body
nil

I did something similar for subtract using seq:

$ janet
Janet 1.33.0-23b0fe9f linux/x64/gcc - '(doc)' for help
repl:1:> (do (use spork/test) :done)
:done

repl:2:> (defn subtract1 [v1 v2] (map - v1 v2))
<function subtract1>

repl:3:> (defn subtract2 [v1 v2] (seq [i :in v1 j :in v2] (- i j)))
<function subtract2>

repl:4:> (timeit-loop [:timeout 10] (subtract1 [2 3 8 9] [1 0 -2 7]))
Elapsed time: 10.000s, 0.8114µs/body
nil

repl:5:> (timeit-loop [:timeout 10] (subtract2 [2 3 8 9] [1 0 -2 7]))
Elapsed time: 10.000s, 0.7926µs/body
nil

Not that different in this case I guess.

But using array/push:

repl:8:> (defn subtract3 [v1 v2] (def res @[]) (for i 0 (length v1) (array/push res (- (get v1 i) (get v2 i)))) res)
<function subtract3>

repl:9:> (timeit-loop [:timeout 10] (subtract3 [2 3 8 9] [1 0 -2 7]))
Elapsed time: 10.000s, 0.2809µs/body
nil

Hmm, maybe using for instead of loop for dot would be faster... looks like it:

repl:8:> (defn dot4 [v1 v2] (var t 0) (for i 0 (length v1) (+= t (* (get v1 i) (get v2 i)))) t)
<function dot4>

repl:9:> (timeit-loop [:timeout 10] "dot4" (dot4 [2 3 8 9] [1 0 -2 7]))
dot4 10.000s, 0.1531µs/body
nil

Don't know if this kind of thing is worth it, but FWIW.

sogaiu commented 7 months ago

Not sure how I feel about:

(defn sign [x]
  (if (>= x 0) 1 -1))

If we're going for this, maybe we want 0 to be returned for the case of x being 0?
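For reference, a three-valued variant that returns 0 for 0 might look like this (just a sketch of the suggestion above, not what the PR currently does):

```janet
# Hypothetical three-valued sign, mapping 0 to 0 as suggested above
# (the PR's version maps 0 to 1).
(defn sign [x]
  (cond
    (> x 0) 1
    (< x 0) -1
    0))
```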

sogaiu commented 7 months ago

I noticed there are now 2 functions named dot:

bakpakin commented 7 months ago

Thanks @hyiltiz , looks interesting!

@sogaiu I wouldn't focus too much on small optimizations - generally, explicitly using loop instead of map will be faster but either way is fine.

hyiltiz commented 7 months ago

Thank you so much for all the comments! I worked on this to introduce basic matrix utilities to Janet, so I started adding whatever was needed until QR and SVD were possible. There are surely many optimizations possible, as @sogaiu has identified. I am more than happy to adopt those changes.

A big question is that spork/math adopted a row-first convention for matrices (an array is understood as a row vector rather than a column vector, so a linear equation is expressed as b = xA, whereas the usual convention is b = Ax, assuming an array represents a column vector). If there are not many external dependencies, it might make sense to deprecate the old convention; I'll have to adjust this PR as well.
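To illustrate the convention difference with a small example (matmul here is a hypothetical row-major helper written just for this sketch; the PR provides its own):

```janet
# Sketch: multiply row-major nested-array matrices (hypothetical helper).
(defn matmul [a b]
  (seq [row :in a]
    (seq [j :in (range (length (first b)))]
      (reduce2 + (map (fn [x brow] (* x (get brow j))) row b)))))

# Row-first convention: a vector is a 1xN row, so a linear map is x -> xA.
(def A @[@[1 2] @[3 4]])
(def x @[@[5 6]])
(matmul x A)
# Under the column convention the same map would be written A^T x^T,
# yielding the transposed (Nx1) result.
```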

On a minor note, the earlier dot seemed too complex for what dot does; unless there are significant efficiency benefits, I think it is better to rely on the compiler to provide optimizations.

sogaiu commented 7 months ago

In general I'm not a fan of trying to do the kinds of optimizations I experimented with above because:

To explain some of the motivation for earlier optimization comments...

primo-ppcg commented 7 months ago

Since I've been asked to comment:


- Perhaps explicitly transposing 1xN or Nx1 matrices would be clearer?
- `transpose` `op` `transpose` should be unnecessary in most cases. For example, `fliplr` could be defined as just `(map reverse m)`.
- There's a little bit of code duplication (e.g. `matmul` redefines `transpose` in place).
- I'm not familiar enough with the QR or SVD algorithms used to be able to comment on the implementations. Is O(n^3) the fastest algorithm known?
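For instance, the suggested map-based fliplr could be sketched as follows (assuming row-major nested arrays, as elsewhere in this thread):

```janet
# fliplr as suggested above: reverse each row of a row-major matrix,
# i.e. flip the matrix left-to-right.
(defn fliplr [m]
  (map reverse m))
```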
hyiltiz commented 7 months ago

The QR algorithm is (one of) the best known; the SVD is (one of) the simplest given QR, and there are many other algorithms designed for efficiency. Those are usually complex enough that they are not re-implemented but simply wrapped/called from the LAPACK/BLAS/Eigen libraries. I think this is a simple and direct first step toward a purely Janet-based implementation.

Thank you so much all for the informative and detailed feedback! Given the valuable feedback above, I'll revise the draft and see if we can get it closer to a PR that we can consider for a merge.

hyiltiz commented 7 months ago

New changes:

sogaiu commented 7 months ago

Some minor comments:

hyiltiz commented 7 months ago

Adjusted for all of the points. Some of the points refer to functions such as sop and binomial-distribution that are not part of this PR but are part of the module touched by this PR. Happy to fix.

sogaiu commented 7 months ago

Ah, sorry about those unrelated bits. Not sure what to do about those. Will think on it.

Hope to look in more detail soon, but wondering about this change. I'm not too familiar with GH CI, but is this intentional?

hyiltiz commented 7 months ago

Yes; that change should allow this PR (and all other incoming PRs) to run the linting and checks without requiring a repo admin to click Approve. Less friction for PR contributions. Probably should've been a separate PR, but it is just a single-line change so may as well...

sogaiu commented 7 months ago

There is a small issue with the tests.

The following diff may fix it:

diff --git a/test/suite-math.janet b/test/suite-math.janet
index 1950c79..edaa477 100644
--- a/test/suite-math.janet
+++ b/test/suite-math.janet
@@ -420,90 +420,90 @@
       res-svd (svd m3)
       U (res-svd :U)
       S (res-svd :S)
-      V (res-svd :V)
-      (assert (deep= m23 m23)
-              "deep= matrix")
-
-      (assert (deep= (flipud m23)
-                     @[@[4 5 6] @[1 2 3]])
-              "flipud")
-
-      (assert (deep= (fliplr m23)
-                     @[@[3 2 1] @[6 5 4]])
-              "fliplr")
-
-      (assert (deep= (join-rows m3 m23)
-                     @[@[1 2 3]
-                       @[4 5 6]
-                       @[7 8 9]
-                       @[1 2 3]
-                       @[4 5 6]])
-              "join-rows")
-
-      (assert (deep= (join-cols m23 m23)
-                     @[@[1 2 3 1 2 3]
-                       @[4 5 6 4 5 6]])
-              "join-cols")
-
-      (assert (m-approx= (res1-m3 :Q)
-                         @[@[-0.123091490979333 -0.492365963917331 -0.861640436855329]
-                           @[-0.492365963917331 0.784145597779528 -0.377745203885826]
-                           @[-0.861640436855329 -0.377745203885826 0.338945893199805]])
-              "qr1-q")
-
-      (assert (m-approx= (res1-m3 :m^)
-                         @[@[-0.0859655700236277 -0.171931140047257]
-                           @[-0.90043974754135 -1.8008794950827]])
-              "qr1-m")
-
-      (assert (m-approx= (res-m3 :Q)
-                         @[@[-0.123091490979333 0.904534033733291 0.408248290463864]
-                           @[-0.492365963917331 0.301511344577765 -0.816496580927726]
-                           @[-0.861640436855329 -0.301511344577764 0.408248290463863]])
-              "qr-q")
-
-      (assert (m-approx= (res-m3 :R)
-                         @[@[-8.12403840463596 -9.60113629638795 -11.0782341881399]
-                           @[-8.88178419700125e-16 0.90453403373329 1.80906806746658]
-                           @[-8.88178419700125e-16 -4.44089209850063e-16 8.88178419700125e-16]])
-              "qr-r")
-
-      (assert (m-approx= U
-                         @[@[0.214837238368396 -0.887230688346371 0.408248290463863]
-                           @[0.520587389464737 -0.249643952988298 -0.816496580927726]
-                           @[0.826337540561078 0.387942782369775 0.408248290463863]])
-              "svd-U")
-
-      (assert (m-approx= S
-                         @[@[16.8481033526142 0 0]
-                           @[-1.1642042401554e-237 -1.06836951455471 0]
-                           @[-6.42285339593621e-323 0 3.62597321469472e-16]])
-
-              "svd-S")
-
-      (assert (m-approx= V
-                         @[@[0.479671177877771 -0.776690990321559 0.408248290463863]
-                           @[0.572367793972062 -0.0756864701045582 -0.816496580927726]
-                           @[0.665064410066353 0.625318050112442 0.408248290463863]])
-              "svd-U")
-
-      (assert (m-approx= (matmul m3 (ident (rows m3)))
-                         m3)
-              "matmul identity left")
-
-      (assert (m-approx= (matmul (ident (rows m3)) m3)
-                         m3)
-              "matmul identity right")
-
-      (assert (m-approx= m3 (matmul (res-m3 :Q) (res-m3 :R)))
-              "qr-square decompose")
-
-      (assert (m-approx= m23 (matmul (res-m23 :Q) (res-m23 :R)))
-              "qr-non-square decompose")
-
-      (assert (m-approx= m3 (reduce matmul (ident (rows U))
-                                    (array U S (trans V))))
-              "svd-USV' decompose")])
+      V (res-svd :V)]
+  (assert (deep= m23 m23)
+          "deep= matrix")
+
+  (assert (deep= (flipud m23)
+                 @[@[4 5 6] @[1 2 3]])
+          "flipud")
+
+  (assert (deep= (fliplr m23)
+                 @[@[3 2 1] @[6 5 4]])
+          "fliplr")
+
+  (assert (deep= (join-rows m3 m23)
+                 @[@[1 2 3]
+                   @[4 5 6]
+                   @[7 8 9]
+                   @[1 2 3]
+                   @[4 5 6]])
+          "join-rows")
+
+  (assert (deep= (join-cols m23 m23)
+                 @[@[1 2 3 1 2 3]
+                   @[4 5 6 4 5 6]])
+          "join-cols")
+
+  (assert (m-approx= (res1-m3 :Q)
+                     @[@[-0.123091490979333 -0.492365963917331 -0.861640436855329]
+                       @[-0.492365963917331 0.784145597779528 -0.377745203885826]
+                       @[-0.861640436855329 -0.377745203885826 0.338945893199805]])
+          "qr1-q")
+
+  (assert (m-approx= (res1-m3 :m^)
+                     @[@[-0.0859655700236277 -0.171931140047257]
+                       @[-0.90043974754135 -1.8008794950827]])
+          "qr1-m")
+
+  (assert (m-approx= (res-m3 :Q)
+                     @[@[-0.123091490979333 0.904534033733291 0.408248290463864]
+                       @[-0.492365963917331 0.301511344577765 -0.816496580927726]
+                       @[-0.861640436855329 -0.301511344577764 0.408248290463863]])
+          "qr-q")
+
+  (assert (m-approx= (res-m3 :R)
+                     @[@[-8.12403840463596 -9.60113629638795 -11.0782341881399]
+                       @[-8.88178419700125e-16 0.90453403373329 1.80906806746658]
+                       @[-8.88178419700125e-16 -4.44089209850063e-16 8.88178419700125e-16]])
+          "qr-r")
+
+  (assert (m-approx= U
+                     @[@[0.214837238368396 -0.887230688346371 0.408248290463863]
+                       @[0.520587389464737 -0.249643952988298 -0.816496580927726]
+                       @[0.826337540561078 0.387942782369775 0.408248290463863]])
+          "svd-U")
+
+  (assert (m-approx= S
+                     @[@[16.8481033526142 0 0]
+                       @[-1.1642042401554e-237 -1.06836951455471 0]
+                       @[-6.42285339593621e-323 0 3.62597321469472e-16]])
+
+          "svd-S")
+
+  (assert (m-approx= V
+                     @[@[0.479671177877771 -0.776690990321559 0.408248290463863]
+                       @[0.572367793972062 -0.0756864701045582 -0.816496580927726]
+                       @[0.665064410066353 0.625318050112442 0.408248290463863]])
+          "svd-U")
+
+  (assert (m-approx= (matmul m3 (ident (rows m3)))
+                     m3)
+          "matmul identity left")
+
+  (assert (m-approx= (matmul (ident (rows m3)) m3)
+                     m3)
+          "matmul identity right")
+
+  (assert (m-approx= m3 (matmul (res-m3 :Q) (res-m3 :R)))
+          "qr-square decompose")
+
+  (assert (m-approx= m23 (matmul (res-m23 :Q) (res-m23 :R)))
+          "qr-non-square decompose")
+
+  (assert (m-approx= m3 (reduce matmul (ident (rows U))
+                                (array U S (trans V))))
+          "svd-USV' decompose"))

 (assert (= 10 (perm @[@[1 2]
sogaiu commented 7 months ago

With the diff above, tests pass for me.


Sorry about the sop and binomial-distribution bits -- I thought I had only commented on lines that had changes on them.

Not sure what happened there (^^;

hyiltiz commented 7 months ago

Patch adopted. Separated workflow change into separate PR: https://github.com/janet-lang/spork/pull/178.

sogaiu commented 7 months ago

00f37604 passed all tests here :+1: