nextzlog / todo

ToDo lists for ATS-4, CW4ISR, QxSL, ZyLO.
https://nextzlog.dev
1 star · 0 forks · source link

VB-EMアルゴリズムによる速度判定 #180

Closed JG1VPP closed 11 months ago

JG1VPP commented 1 year ago

問題意識

現状のCW4ISRの速度判定はk-meansで実現しているが、外れ値や速度変化に脆弱。

解決方法

変分ベイズ推定により、速度変化やノイズに対する頑健性を高める。ただし、速度変化に追随するには別の方法が必要。

JG1VPP commented 1 year ago

Scalaで実装

『Scalaで実装するパターン認識と機械学習』より:

/** Lloyd-style k-means clustering; the constructor runs `epochs` update rounds.
  *
  * @param x      samples to cluster (all expected to have the same dimension)
  * @param k      number of clusters
  * @param epochs number of assignment/update rounds run by the constructor
  */
class Kmeans(x: Seq[Seq[Double]], k: Int, epochs: Int = 100) {
    /** Centroids, one row per cluster, initialised uniformly in [0, 1).
      * Sized by the shortest sample; `.min` throws if `x` is empty.
      */
    val mu = Array.fill(k, x.map(_.size).min)(math.random)

    /** Index of the centroid closest to `x` in squared Euclidean distance. */
    def apply(x: Seq[Double]) = mu.map(quads(x)(_).sum).zipWithIndex.minBy(_._1)._2

    /** Component-wise squared differences of `a` and `b` (truncated to the shorter one). */
    def quads(a: Seq[Double])(b: Seq[Double]) = a.zip(b).map(_-_).map(d=> d * d)

    /** Mean of every non-empty cluster, keyed by its cluster index. */
    def means = x.groupBy(apply).view.mapValues(c => c.transpose.map(_.sum / c.size)).toMap

    /** Means of the non-empty clusters, in unspecified order (kept for existing callers). */
    def estep = means.values

    /** One update round: every mean is written back to the centroid of the SAME cluster,
      * and centroids of empty clusters are left unchanged. Zipping the unordered (and,
      * with empty clusters, shorter) group means against `mu` would misplace them.
      */
    def step(): Unit = means.foreach { case (j, c) => c.copyToArray(mu(j)) }

    for(epoch <- 0 until epochs) step()
}

/** Diagonal-covariance Gaussian mixture with `k` components in `d` dimensions.
  *
  * Starts with uniform weights and random means/variances in [0, 1); the trainers in
  * this file (EM, MstepGMM) overwrite the arrays in place via `copyToArray`.
  */
class GMM(val d: Int, val k: Int) {
    /** Mixing weights, one per component. */
    val w = Array.fill(k)(1.0 / k)
    /** Means (`m`) and per-dimension variances (`s`), one row per component. */
    val (m, s) = (Array.fill(k, d)(math.random), Array.fill(k, d)(math.random))
    /** Weighted density of `x` under each component, in component order. */
    def apply(x: Seq[Double]) = Array.tabulate(k)(j => Normal(x)(w(j), m(j), s(j)).p)
}

/** Weighted diagonal-covariance normal density evaluated at `x`.
  *
  * @param x point at which the density is evaluated
  * @param w weight multiplied into `p`
  * @param m mean vector
  * @param s per-dimension variances (the covariance diagonal, not standard deviations)
  */
case class Normal(x: Seq[Double])(w: Double, m: Seq[Double], s: Seq[Double]) {
    /** Unnormalised kernel exp(-1/2 * sum_d (x_d - m_d)^2 / s_d). */
    def n = {
        val scaledSquares = x.lazyZip(m).lazyZip(s).map((xd, md, sd) => (xd - md) * (xd - md) / sd)
        math.exp(-0.5 * scaledSquares.sum)
    }
    /** w * n / ((2 pi)^(D/2) * sqrt(prod_d s_d)), where D = x.size. */
    def p = {
        val weighted = w * n
        weighted / math.pow(2 * math.Pi, 0.5 * x.size) / math.sqrt(s.product)
    }
}

/** Maximum-likelihood EM for a diagonal-covariance GMM; `mm` is updated in place.
  *
  * @param x      training samples
  * @param mm     mixture whose `w`, `m` and `s` arrays are overwritten
  * @param epochs number of E/M rounds run by the constructor
  */
class EM(val x: Seq[Seq[Double]], val mm: GMM, epochs: Int = 100) {
    /** M-step from responsibilities `P` (one row per component, one column per sample).
      * Weights become the mean responsibility, means the responsibility-weighted mean, and
      * variances the weighted second moment minus the squared mean.
      */
    def mstep(P: Seq[Seq[Double]]) = {
        P.map(_.sum / x.size).copyToArray(mm.w)
        for (((resp, mean), vari) <- P.zip(mm.m).zip(mm.s)) {
            val mass = resp.sum
            val first = resp.zip(x).map((r, xn) => xn.map(v => r * v)).transpose.map(_.sum)
            val second = resp.zip(x).map((r, xn) => xn.map(v => r * v * v)).transpose.map(_.sum)
            first.map(_ / mass).copyToArray(mean)
            second.map(_ / mass).copyToArray(vari)
        }
        for ((vari, mean) <- mm.s.zip(mm.m); d <- 0 until mm.d) vari(d) -= mean(d) * mean(d)
    }

    /** E-step: per-sample component densities normalised to sum to one, then transposed
      * to one row per component.
      */
    private def responsibilities = x.map { xn =>
        val lik = mm(xn)
        val total = lik.sum
        lik.map(_ / total)
    }.transpose

    (1 to epochs).foreach(_ => mstep(responsibilities))
}

/** Variational-Bayes trainer for a diagonal-covariance GMM (cf. Bishop, PRML section 10.2).
  *
  * One array `n` plays three roles: Dirichlet concentration (alpha), mean-precision scale
  * (beta) and Wishart degrees of freedom (nu). `w` holds the diagonal of the inverse Wishart
  * scale, and `m` holds the component means. `n`, `w` and `m` are the posterior
  * hyperparameters rewritten every epoch. `n0`, `w0` and `m0` are frozen copies of the
  * initial values and act as the prior.
  *
  * @param x      training samples
  * @param mm     mixture whose point estimates are refreshed by every M-step
  * @param epochs number of E/M rounds run by the constructor
  * @param W      initial (prior) value of every entry of `w`
  */
class VB(val x: Seq[Seq[Double]], val mm: GMM, epochs: Int = 1000, W: Double = 1) {
    val n = Array.fill(mm.k)(1.0 / mm.k)
    val w -> m = (Array.fill(mm.k, mm.d)(W), Array.fill(mm.k, mm.d)(math.random))
    // Prior hyperparameters. The M-step must start from these, not from the previous
    // posterior: reusing the posterior added N_k to `n` (and the scatter to `w`) on every
    // epoch, so after many epochs the data was effectively counted `epochs` times.
    // NOTE(review): a Wishart prior needs nu0 > D - 1, but nu0 = 1/k here; confirm intended.
    val n0 = n.clone
    val w0 = w.map(_.clone)
    val m0 = m.map(_.clone)
    for(epoch <- 1 to epochs) new MstepGMM(this, mm, new EstepGMM(this, mm).post)
}

/** Variational E-step: responsibilities of every sample for every component.
  *
  * Terms identical for all components (the D/2 ln 2pi and D ln 2 constants) are omitted,
  * since they cancel when each sample's row is normalised.
  */
class EstepGMM(vb: VB, mm: GMM) {
    /** E[ln pi_k] = digamma(alpha_k) - digamma(sum alpha). */
    val eq35 = vb.n.map(Digamma).map(_-Digamma(vb.n.sum))
    /** digamma((nu_k - d) / 2) for d = 0 .. D-1. */
    val eq3A = vb.n.map(n=>0.to(mm.d-1).map(d=>(n-d)/2).map(Digamma))
    /** Half of E[ln|Lambda_k|] up to a constant, using ln|W_k| = -sum_d ln w_kd. */
    val eq36 = eq3A.zip(vb.w).map(_.sum-_.map(math.log).sum).map(_/2)
    /** (m_kd - x_nd)^2 / w_kd for every sample n, component k and dimension d. */
    def wish = vb.x.toArray.map(_.toArray).map(vb.m-_).map(d=>d.mul(d).div(vb.w))
    /** Unnormalised log responsibility: eq35 + eq36 - nu_k/2 * quadratic form - D/(2 beta_k). */
    def eq34 = wish.map(_.zip(vb.n).map(-_.sum/2*_))+eq35+eq36-vb.n.map(mm.d/_/2)
    /** Normalised responsibilities, transposed to one row per component.
      *
      * The row maximum is subtracted before `exp` (log-sum-exp trick). The result is
      * mathematically unchanged, but rows whose log-weights are all very negative no longer
      * underflow to 0/0 = NaN.
      */
    def post = eq34.map { row =>
        val top = row.max
        row.map(v => math.exp(v - top))
    }.map(x=>x.map(_/x.sum)).toSeq.transpose
}

/** Variational M-step: rewrites `vb.n`, `vb.m` and `vb.w` from the prior (`n0`, `m0`, `w0`)
  * and the responsibilities `post` (one row per component), following Bishop (10.58)-(10.63)
  * specialised to a diagonal precision. It also refreshes the point estimates in `mm` through
  * `EM.mstep`.
  */
class MstepGMM(vb: VB, mm: GMM, post: Seq[Seq[Double]]) {
    new EM(vb.x, mm, 0).mstep(post)
    // sum_n r_nk x_n and sum_n r_nk x_n^2, computed directly rather than as N_k * mean.
    // A component with N_k = 0 then contributes 0 instead of 0/0 = NaN.
    private val sx = post.map(r => r.zip(vb.x).map((p, xn) => xn.map(p * _)).transpose.map(_.sum).toArray).toArray
    private val sxx = post.map(r => r.zip(vb.x).map((p, xn) => xn.map(v => p * v * v)).transpose.map(_.sum).toArray).toArray
    /** N_k = sum_n r_nk. */
    val eq11 = post.map(_.sum).toArray
    /** Posterior n_k = n0_k + N_k (shared alpha / beta / nu). */
    val eq38 = vb.n0.zip(eq11).map(_+_)
    /** Posterior mean m_k = (n0_k * m0_k + sum_n r_nk x_n) / n_k. */
    val eq39 = vb.m0.mul(vb.n0).add(sx).div(eq38)
    /** n0_k * m0_k^2 - n_k * m_k^2 (completes the square in eq40). */
    val eq41 = vb.m0.mul(vb.m0).mul(vb.n0).sub(eq39.mul(eq39).mul(eq38))
    /** Posterior inverse-scale diagonal w_k = w0_k + sum_n r_nk x_n^2 + eq41. */
    val eq40 = sxx.add(vb.w0.add(eq41))
    eq38.copyToArray(vb.n)
    eq39.zip(vb.m).foreach(_.copyToArray(_))
    eq40.zip(vb.w).foreach(_.copyToArray(_))
}

/** Element-wise arithmetic on matrices stored as `Array[Array[Double]]`.
  *
  * Every operation returns new arrays. Like `zip`, each silently truncates to the shorter
  * operand.
  */
implicit class Vector(x: Array[Array[Double]]) {
    /** Adds `y` column by column to every row. */
    def +(y: Array[Double]) = x.map(row => row.lazyZip(y).map(_ + _))
    /** Subtracts `y` column by column from every row. */
    def -(y: Array[Double]) = x.map(row => row.lazyZip(y).map(_ - _))
    /** Adds the scalar `y(i)` to every entry of row `i`. */
    def add(y: Array[Double]) = x.lazyZip(y).map((row, c) => row.map(_ + c))
    /** Subtracts the scalar `y(i)` from every entry of row `i`. */
    def sub(y: Array[Double]) = x.lazyZip(y).map((row, c) => row.map(_ - c))
    /** Multiplies every entry of row `i` by the scalar `y(i)`. */
    def mul(y: Array[Double]) = x.lazyZip(y).map((row, c) => row.map(_ * c))
    /** Divides every entry of row `i` by the scalar `y(i)`. */
    def div(y: Array[Double]) = x.lazyZip(y).map((row, c) => row.map(_ / c))
    /** Entry-wise sum with `y`. */
    def add(y: Array[Array[Double]]) = x.lazyZip(y).map((a, b) => a.lazyZip(b).map(_ + _))
    /** Entry-wise difference with `y`. */
    def sub(y: Array[Array[Double]]) = x.lazyZip(y).map((a, b) => a.lazyZip(b).map(_ - _))
    /** Entry-wise product with `y`. */
    def mul(y: Array[Array[Double]]) = x.lazyZip(y).map((a, b) => a.lazyZip(b).map(_ * _))
    /** Entry-wise quotient with `y`. */
    def div(y: Array[Array[Double]]) = x.lazyZip(y).map((a, b) => a.lazyZip(b).map(_ / _))
}

/** Digamma function psi(x) = d/dx ln Gamma(x).
  *
  * Uses the recurrence psi(z) = psi(z + 1) - 1/z to move the argument up to at least 49, then
  * applies the asymptotic series ln z - 1/(2z) - 1/(12z^2) + 1/(120z^4) - 1/(252z^6).
  * Non-positive integer arguments hit a division by zero in the recurrence.
  */
object Digamma extends Function[Double, Double] {
    /** Argument below which the recurrence is applied before the asymptotic series. */
    private val Shift = 49

    def apply(x: Double): Double = {
        @annotation.tailrec
        def climb(z: Double, acc: Double): (Double, Double) =
            if (z < Shift) climb(z + 1, acc - 1 / z) else (z, acc)
        val (z, acc) = climb(x, 0.0)
        val inv2 = 1.0 / (z * z)
        val series = inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 / 252))
        acc + math.log(z) - 0.5 / z - series
    }
}
JG1VPP commented 11 months ago

k-meansと比べて収束性が悪化した。凍結。