o-jill / ruversi

reversi program.
https://o-jill.github.io/ruversi/
1 stars 0 forks source link

活性化関数 #33

Open o-jill opened 2 years ago

o-jill commented 2 years ago

今はシンプルにシグモイドだが、計算を簡単にして処理速度を上げるためにsoftsign(y = x / (1+|x|))とかにしちゃうのはどうか? |x|ならsimd化も簡単にできるかもよ。

rustはどのexp(x)を使ってんの?独自?gcc? gcc版はvcとかと比べるとちょっと遅いらしい。

exp_ps(x)みたいなやつは自分で作ってね。 世の中にはいっぱいありそう。 crate無いの?

o-jill commented 2 years ago

sigmoidよりsoftsignのほうが早く学習が収束してエラー率もsoftsignのほうがいいんだって。 https://club.informatix.co.jp/?p=11274

o-jill commented 2 years ago
fn softsign(x : f32) -> f32 {
    x / (1 + abs(x))
}

// y = x / (1 + |x|)
fn softsign_v4(x : *const f32, y : *mut f32) {
    unsafe {
    let x4 = x86_64::_mm_load_ps(x);
    let signmask = x86_64::_mm_set1_ps(-0.0);
    let absx4 = x86_64::_mm_andnot_ps(signmask, x4);
    let one = x86_64::_mm_set1_ps(1.0);
    let axp1 =x86_64::_mm_add_ps(one, absx4);
    let res = x86_64::_mm_ div_ps(x4, axp1);
    x86_64::_mm_store_ps(y, res);
    }
}

// y = 0.5 * x / (1 + |x|) + 0.5
fn softsign01_v4(x : *const f32, y : *mut f32) {
    unsafe {
    let _05 = x86_64::_mm_set1_ps(0.5);
    let x4 = x86_64::_mm_load_ps(x);
    let signmask = x86_64::_mm_set1_ps(-0.0);
    let absx4 = x86_64::_mm_andnot_ps(signmask, x4);
    let one = x86_64::_mm_set1_ps(1.0);
    let axp1 =x86_64::_mm_add_ps(one, absx4);
    let y_11 = x86_64::_mm_div_ps(x4, axp1);
    let y05 = x86_64::_mm_mul_ps(y_11, _05);
    let res = x86_64::_mm_add_ps(y05, _05);
    x86_64::_mm_store_ps(y, res);
    }
}
o-jill commented 1 year ago

大体実行時間が sigmoid() : softsign() : softsign01() = 28 : 11 : 12

use rand::Rng;

// 0 ~ 1
fn sigmoid(x : f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// -1 ~ 1
fn softsign(x : f32) -> f32 {
    x / (1.0 + x.abs())
}

// 0 ~ 1
fn softsign01(x : f32) -> f32 {
    0.5 * x / (1.0 + x.abs()) + 0.5
}

fn main() {
    println!("Hello, world!");
    let mut v = [0.0 as f32 ; 300];
    let mut u = [0.0 as f32 ; 300];
    let mut rng = rand::thread_rng();

    let now = std::time::Instant::now();
    for _i in 0..1000000 {
        for j in 0..200 {
            let x = rng.gen::<f64>() * 10.0;
            let x = x as f32;
            v[j] = sigmoid(x);
            let xx = 1.0 / x;
            u[j] = sigmoid(xx);
        }
    }
    println!("{}", now.elapsed().as_millis());

    let now = std::time::Instant::now();
    for _i in 0..1000000 {
        for j in 0..200 {
            let x = rng.gen::<f64>() * 10.0;
            let x = x as f32;
            v[j] = softsign(x);
            let xx = 1.0 / x;
            u[j] = softsign(xx);
        }
    }
    println!("{}", now.elapsed().as_millis());

    let now = std::time::Instant::now();
    for _i in 0..1000000 {
        for j in 0..200 {
            let x = rng.gen::<f64>() * 10.0;
            let x = x as f32;
            v[j] = softsign01(x);
            let xx = 1.0 / x;
            u[j] = softsign01(xx);
        }
    }
    println!("{}", now.elapsed().as_millis());
    for j in 0..50 {
        print!("{}", v[j]);
    }
    println!("");
    for j in 0..50 {
        print!("{}", u[j]);
    }
}
o-jill commented 1 year ago

なんとなくシグモイドどうこうよりも入力のパターンを沢山学習/記憶出来る方がいい気がしてきたのでsoftsignにして速度を上げて中間層を増やすのがいい気がしてきた。

o-jill commented 1 year ago

初期局面からdepth15で計測。simd無し。 sigmoid: 1.1Mnps softsign: 1.3Mnps simdが入るとsoftsignがもう少し速くなるかも? sigmoid-simd: 2.2Mnps softsign-simd: 2.5Mnps simdの影響でexp()の計算回数が減ってちょっと近づいた? 20%弱は速くなりそう。

o-jill commented 1 year ago

学習速度への影響?

o-jill commented 1 year ago

探索の速度アップの効果について、

79 から、実行時間の約1/4が評価に費やしている時間なので、

simd版で、評価1回の時間は50%(=(2.5/2.2-0.75)/0.25)の速度アップってことになりそう。

o-jill commented 1 year ago

8687kifuを5回学習、棋譜ファイルキャッシュ後 softsign: nosimd: 6.8sec, simd: 2.6sec sigmoid: nosimd: 7.5sec, simd: 2.9sec

約10%高速化になる。 softsignの計算はカスケード処理なのでココもSIMDにするともうちょっと速くなるかもしれない。 けど中間層の数が少ないので効果は限定的かも。

o-jill commented 1 year ago

↑はforward()をミスってたけど直した後も実行時間は変わらなかった。 testでnosimdとsimdの一致を見るようにした。sigmoidと違って基本的におなじになってくれるっぽい?それともsigmoidのやつがバグっている?

o-jill commented 1 year ago

せっかくなのでtraitの勉強しますか?

o-jill commented 1 year ago

こんな感じ?

o-jill commented 8 months ago

wikipedia見てたらf(x) = x / sqrt(x^2+1)ってのもシグモイドっぽいかたちになるらしい。 https://ja.wikipedia.org/wiki/%E6%B4%BB%E6%80%A7%E5%8C%96%E9%96%A2%E6%95%B0 f'(x) = 1/sqrt(x^2+1)+ 2x^2/sqrt(x^2+1)^3 f'(x) = (x^2 + 1 + 2x^2)/sqrt(x^2+1)^3 f'(x) = (3x^2 + 1)/sqrt(x^2+1)^3 で合ってる?

o-jill commented 8 months ago

hard sigmoidってのがあるらしいっすよ。 https://qiita.com/kuroitu/items/73cd401afd463a78115a#hardsigmoid%E9%96%A2%E6%95%B0 https://qiita.com/hsjoihs/items/88d1569aaef01659bbd5

入力をa, -aでクリップ、 xで計算。(2ndのときはプラス用とマイナス用両方計算して片方を使う) 割り算はない(事前に小数に変換可能) N次で近似可能?4次ぐらいまで? image

xの次数とa https://gist.github.com/o-jill/2f7fa9ac787be187f471539f9b9bd308

o-jill commented 8 months ago

hard_sigmoidをsseしてみた。 https://gist.github.com/o-jill/7bfa8fc7497adbc8349057003f1451fd 精度的に16乗ぐらいまで(x:20)ぐらいまでが限界っぽい。 特にx>0側、1からかなり小さい数字を引くことになるので1になってしまう。

o-jill commented 8 months ago

hardsigmoid時間計測 時間的には8次近似ぐらいまでかな。

N=100
- sigmoid:
speed: 3330.02 nodes/msec
2358853 nodes / 708.36 ± 26.32 msec (663 -- 801)
speed: 3316.43 nodes/msec
2265722 nodes / 683.18 ± 29.62 msec (637 -- 763)
speed: 3515.18 nodes/msec
93293 nodes / 26.54 ± 4.62 msec (24 -- 47)

duration: 17.215 sec.
2.15 sec/game = 17.215 / 8

- hardsigmoid 2:
speed: 3549.70 nodes/msec
2328605 nodes / 656.00 ± 30.38 msec (613 -- 750)
speed: 3554.14 nodes/msec
2254714 nodes / 634.39 ± 26.53 msec (590 -- 712)
speed: 3876.58 nodes/msec
92379 nodes / 23.83 ± 1.89 msec (23 -- 37)

duration: 15.313 sec.
1.91 sec/game = 15.313 / 8

- hardsigmoid 4:
speed: 3504.26 nodes/msec
2373718 nodes / 677.38 ± 21.96 msec (635 -- 743)
speed: 3445.72 nodes/msec
2248746 nodes / 652.62 ± 30.15 msec (602 -- 758)
speed: 3635.91 nodes/msec
92534 nodes / 25.45 ± 5.28 msec (23 -- 54)

duration: 16.921 sec.
2.12 sec/game = 16.921 / 8

- hardsigmoid 8:
speed: 3482.84 nodes/msec
2392433 nodes / 686.92 ± 27.65 msec (648 -- 789)
speed: 3491.20 nodes/msec
2243446 nodes / 642.60 ± 28.68 msec (602 -- 755)
speed: 3651.39 nodes/msec
92928 nodes / 25.45 ± 5.35 msec (23 -- 53)

duration: 17.354 sec.
2.17 sec/game = 17.354 / 8

speed: 3483.25 nodes/msec
2392433 nodes / 686.84 ± 33.42 msec (632 -- 789)
speed: 3462.05 nodes/msec
2243446 nodes / 648.01 ± 32.28 msec (600 -- 812)
speed: 3575.53 nodes/msec
92928 nodes / 25.99 ± 5.32 msec (23 -- 53)

speed: 3497.86 nodes/msec
2392433 nodes / 683.97 ± 25.71 msec (640 -- 752)
speed: 3510.98 nodes/msec
2243446 nodes / 638.98 ± 27.40 msec (597 -- 717)
speed: 3715.63 nodes/msec
92928 nodes / 25.01 ± 4.13 msec (23 -- 49)
o-jill commented 7 months ago

f(x) = 1 - 0.5 * ((a-x)/a)^Nを微分するとf'(x) = -0.5 * N/a * ((a-x)/a)^(N-1) f(x) = 0.5 * ((a+x)/a)^Nを微分するとf'(x) = 0.5 * N/a * ((a+x)/a)^(N-1)