lifthrasiir / roadroller

Roadroller: Flattens Your JavaScript Demo
https://lifthrasiir.github.io/roadroller/

Alternative activation function #3

Open lifthrasiir opened 3 years ago

lifthrasiir commented 3 years ago

Context mixing commonly uses the logistic function f(x) = 1/(1+exp(-x)) as an activation function, but it is not the only possibility. Since Math.log and Math.exp are quite expensive, and their accuracy is implementation-defined (although all known browsers use the same approximation), it is worth checking whether an alternative activation function can be used.
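For reference, here is a minimal floating-point sketch of the usual logistic pair in context mixing: squash(x) = f(x) and its inverse stretch(p) = ln(p/(1-p)), plus the weighted-sum mixing step. The names squash, stretch and mix are illustrative only, not Roadroller's API, and Roadroller itself works on fixed-point integers rather than floats. It shows where the Math.log (per model) and Math.exp (per mixed prediction) calls come from.

```javascript
// Logistic activation ("squash") and its inverse ("stretch").
// Probabilities are plain floats in (0, 1) here, not fixed-point integers.
function squash(x) {
  return 1 / (1 + Math.exp(-x));
}

function stretch(p) {
  return Math.log(p / (1 - p));
}

// Mix model predictions: stretch each probability, take the weighted sum,
// then squash the total back into (0, 1). This costs one Math.log per model
// and one Math.exp per mixed prediction.
function mix(probs, weights) {
  let total = 0;
  for (let i = 0; i < probs.length; ++i) {
    total += weights[i] * stretch(probs[i]);
  }
  return squash(total);
}

// Usage: two models predicting 0.9 and 0.6, equally weighted.
console.log(mix([0.9, 0.6], [0.5, 0.5]));
```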

I've personally tried two possibilities: g(x) = x/sqrt(1+x^2) and h(x) = x/(1+abs(x)). (Note that IEEE 754-2008 requires sqrt to be correctly rounded.) After scaling, both are roughly similar to the original logistic function:

[Figure: the logistic function f(x) compared with g(x) and h(x) after scaling]
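The sketch below writes out both candidates with their closed-form inverses, which take the place of stretch. It assumes the scaling is the one implied by the diff that follows: a probability p corresponds to the signed value 2p - 1 in (-1, 1), so the logistic function becomes 2f(x) - 1 = tanh(x/2). The helper names (gInv, hInv, fScaled, squashG, stretchG) are hypothetical. Note that g needs one Math.sqrt per call, while h needs only an abs and a division.

```javascript
// The two candidate activations, both mapping the real line onto (-1, 1),
// with their closed-form inverses (used in place of "stretch").
const g = (x) => x / Math.sqrt(1 + x * x);
const gInv = (y) => y / Math.sqrt(1 - y * y);
const h = (x) => x / (1 + Math.abs(x));
const hInv = (y) => y / (1 - Math.abs(y));

// The logistic function rescaled from (0, 1) to (-1, 1); equal to tanh(x / 2).
const fScaled = (x) => 2 / (1 + Math.exp(-x)) - 1;

// g used as squash/stretch on the (0, 1) probability scale: a probability p
// corresponds to the signed value 2p - 1 and vice versa.
const squashG = (x) => (g(x) + 1) / 2;
const stretchG = (p) => gInv(2 * p - 1);

// Print the three curves side by side for a few inputs.
for (const x of [-4, -2, -1, 0, 1, 2, 4]) {
  console.log(x, fScaled(x).toFixed(3), g(x).toFixed(3), h(x).toFixed(3));
}
```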

g(x) turned out to be a drop-in replacement for the logistic function, while h(x) required adjusting the learning rate (by about 2-10). As an example, the following is the change required against 3f60e44b71702c69ae09ca426466933a72a7ad2c for g(x):

         for (let i = 0; i < this.models.length; ++i) {
             const weight = this.weights[i];
-            const prob = this.models[i].predict(context) * 2 + 1;
-            const stretchedProb = Math.log(prob / ((2 << this.precision) - prob));
+            const prob = (this.models[i].predict(context) * 2 + 1) / (1 << this.precision) - 1;
+            const stretchedProb = prob / Math.sqrt(1 - prob * prob);
             this.stretchedProbs[i] = stretchedProb;
             total += weight * stretchedProb;
         }

         // since CM is the last model and predictions are not stored,
         // we can just compute the external probability directly.
-        const mixedProb = (2 << this.precision) / (1 + Math.exp(-total));
+        const mixedProb = (total / Math.sqrt(1 + total * total) + 1) * (1 << this.precision);
         if (mixedProb >= (2 << this.precision)) {
             throw new Error('LogisticMixModel.predict: weighted average overflow');
         }
                 // calculate the mixed prediction q
                 `x=u.map((j,i)=>(` +
-                    `y=p[j]*2+1,` +
+                    `y=(p[j]*4+2)/M-1,` +
                     // stretch(prob), needed for updates
-                    `y=Math.log(y/(M-y)),` +
-                    `q-=w[i]*y,` +
+                    `y/=Math.sqrt(1-y*y),` +
+                    `q+=w[i]*y,` +
                     // premultiply with learning rate
                     `y${learningRateNum == 1 ? '' : '*'+learningRateNum}/${learningRateDenom}` +
                 `)),` +

                 // q: squash(sum of weighted preds) followed by adjustment
-                `q=M/(1+Math.exp(q))|1,` +
+                `q=(q/Math.sqrt(1+q*q)+1)*M/2|1,` +
                 // decode the bit b
                 `b=t%M<q,` +
                 `t=(b?q:M-q)*(t>>${precision + 1})+t%M-!b*q` +
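To compare the two encoder-side variants outside Roadroller, the predict change above can be restated as two standalone functions. This is a sketch under stated assumptions: mixLogistic, mixSqrt and the plain array-of-weights interface are hypothetical rather than Roadroller's LogisticMixModel API, and the weight update and overflow check are omitted. Both map into the same (0, 2 << precision) range, which is why g(x) can be swapped in without touching the rest of the coder.

```javascript
// Hypothetical stand-alone versions of the two mixing variants in the diff.
// preds: model predictions as integers in [0, 1 << precision).
// weights: one float weight per model.
// Both return the mixed probability scaled to (0, 2 << precision).

function mixLogistic(preds, weights, precision) {
  let total = 0;
  for (let i = 0; i < preds.length; ++i) {
    const prob = preds[i] * 2 + 1; // odd value in (0, 2 << precision)
    total += weights[i] * Math.log(prob / ((2 << precision) - prob));
  }
  return (2 << precision) / (1 + Math.exp(-total));
}

function mixSqrt(preds, weights, precision) {
  let total = 0;
  for (let i = 0; i < preds.length; ++i) {
    // Map the prediction to (-1, 1), then stretch with y / sqrt(1 - y^2).
    const prob = (preds[i] * 2 + 1) / (1 << precision) - 1;
    total += weights[i] * (prob / Math.sqrt(1 - prob * prob));
  }
  // Squash with g(x) = x / sqrt(1 + x^2), then map (-1, 1) to (0, 2 << precision).
  return (total / Math.sqrt(1 + total * total) + 1) * (1 << precision);
}

// Usage: with a single model of weight 1, both variants return
// preds[0] * 2 + 1 (up to rounding), i.e. the prediction passes through.
console.log(mixLogistic([1000], [1], 12), mixSqrt([1000], [1], 12));
```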

Unfortunately, neither was good enough to replace the logistic function: g(x) was 1 to 400 bytes larger and h(x) was more than 1000 bytes larger for most samples I've tried. This is really unfortunate, as g(x) was significantly faster than f(x) in V8 (by about 10%). Any other suggestions are welcome.

lifthrasiir commented 2 years ago

As of 2.1.0, g(x) and h(x) still perform much worse than f(x), even with -O2.