aqtech-ca / mctreesearch4j

MIT License

@JunTaoLuo I'm currently unable to do this without heavily modifying `State{less/ful}Solver`, because the members involved, namely the `root` property and the `calculateUCT` method, are `private`. If we change their visibility to `protected` (or `open protected`, if subclasses should also be able to override them), subclasses will be able to access them. The code below should work if we make this change across the codebase. #17

Closed: larkz closed this issue 3 years ago.

larkz commented 3 years ago

Not ideal, but we should make the solver classes extensible.
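For context, Kotlin classes and members are final by default, so a class can only be subclassed if it is declared `open`. `private` members are invisible to subclasses. Marking a member `protected` exposes it to subclasses, and adding `open` also lets them override it. The sketch below illustrates that pattern. It uses hypothetical `BaseSolver` and `TrackedSolver` classes with a placeholder `score` function, not the project's actual solvers.

```kotlin
// Hypothetical classes illustrating the pattern; not the project's actual solvers.
open class BaseSolver(protected val iterations: Int) {
    // protected: accessible to subclasses, hidden from outside callers.
    protected var visits = 0

    // protected + open: subclasses may call it and may override it.
    protected open fun score(n: Int): Double = 1.0 / (n + 1)

    // protected but final: subclasses may call it, but not override it.
    protected fun iterateStep() {
        visits++
    }

    open fun run(): Int {
        repeat(iterations) { iterateStep() }
        return visits
    }
}

// Subclass that records a per-iteration value without copying the base logic.
class TrackedSolver(iterations: Int) : BaseSolver(iterations) {
    val scores = mutableListOf<Double>()

    override fun run(): Int {
        repeat(iterations) {
            iterateStep()              // protected member of BaseSolver
            scores.add(score(visits))  // protected state and method
        }
        return visits
    }
}
```

The trade-off is that every `protected` member becomes part of the contract that subclasses rely on. This makes those members harder to change later.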

@JunTaoLuo I'm currently unable to do this without heavily modifying `State{less/ful}Solver`, because the members involved, namely the `root` property and the `calculateUCT` method, are `private`. If we change their visibility to `protected` (or `open protected`, if subclasses should also be able to override them), subclasses will be able to access them. The code below should work if we make this change across the codebase:

```kotlin
package MDPSolver

import MDP
import StatelessSolver
import kotlin.math.ln
import kotlin.math.sqrt
import kotlin.random.Random

// Assumes StatelessSolver is declared `open` and that `root`, `calculateUCT`,
// `initialize` and `iterateStep` are made `protected` (they are currently private).
class StatelessSolverTracked<TState, TAction>(
        mdp: MDP<TState, TAction>,
        random: Random,
        private val iterations: Int,
        simulationDepthLimit: Int,
        private val explorationConstant: Double,
        rewardDiscountFactor: Double,
        verbose: Boolean): StatelessSolver<TState, TAction>(
        mdp,
        random,
        iterations,
        simulationDepthLimit,
        explorationConstant,
        rewardDiscountFactor,
        verbose) {

    // Runs the search for `iterations` steps and, after each step, records the
    // exploration term c * sqrt(ln(i) / n) of the root child with the highest UCT.
    fun buildTreeTracked(): MutableList<Double> {
        initialize()

        val rewardTracker = mutableListOf<Double>()

        // Start at 1: i = 0 would give ln(0) = -Infinity (a NaN entry),
        // and 0..iterations would also run one extra iteration.
        for (i in 1..iterations) {
            iterateStep()

            // Visit count of the root child currently favoured by UCT.
            val ns = root!!.children.maxByOrNull { a -> calculateUCT(a) }!!.n
            val explorationFactor = explorationConstant * sqrt(ln(i.toDouble()) / ns)
            rewardTracker.add(explorationFactor)
        }
        return rewardTracker
    }
}
```
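The value recorded per iteration is the UCT exploration term c·sqrt(ln N / n), where N is the parent's visit count and n is the child's visit count. Below is a hypothetical standalone helper (`explorationTerm` is not part of the project) that guards the two edge cases the loop can hit. It returns 0.0 when N ≤ 1, where ln N ≤ 0. It returns +Infinity for an unvisited child, following the common convention that unvisited children are explored first.

```kotlin
import kotlin.math.ln
import kotlin.math.sqrt

// Hypothetical helper: UCT exploration term c * sqrt(ln(parentVisits) / childVisits).
fun explorationTerm(c: Double, parentVisits: Int, childVisits: Int): Double = when {
    childVisits == 0 -> Double.POSITIVE_INFINITY  // unvisited child: explore first
    parentVisits <= 1 -> 0.0                      // ln(1) = 0; avoids ln(0) = -Infinity
    else -> c * sqrt(ln(parentVisits.toDouble()) / childVisits)
}
```

In `buildTreeTracked`, the loop counter `i` stands in for the parent's visit count. If the root node exposes `n` like its children do, passing `root!!.n` would state the intent more directly.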

Alternatively, should we keep the original implementation and use a flag to enable the `rewardTracker` tracking?

_Originally posted by @larkz in https://github.com/JunTaoLuo/KotlinMCTS/pull/14#discussion_r579719975_
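The flag alternative raised above could look roughly like the sketch below. The `FlaggedSolver` class, its `trackExploration` flag and `explorationHistory` list are hypothetical names. `iterateStep` is a placeholder that only updates made-up visit counts instead of running a real MCTS iteration.

```kotlin
import kotlin.math.ln
import kotlin.math.sqrt

// Hypothetical sketch: a constructor flag decides whether per-iteration values are recorded.
class FlaggedSolver(
    private val iterations: Int,
    private val explorationConstant: Double,
    private val trackExploration: Boolean = false
) {
    private var rootVisits = 0
    private var bestChildVisits = 0

    // Recorded values; stays empty unless trackExploration is true.
    val explorationHistory = mutableListOf<Double>()

    // Placeholder for one MCTS iteration (select, expand, simulate, backpropagate).
    private fun iterateStep() {
        rootVisits++
        bestChildVisits = (rootVisits + 1) / 2  // made-up statistic, always >= 1
    }

    fun buildTree() {
        repeat(iterations) {
            iterateStep()
            if (trackExploration) {
                explorationHistory.add(
                    explorationConstant * sqrt(ln(rootVisits.toDouble()) / bestChildVisits)
                )
            }
        }
    }
}
```

This keeps a single solver class, and the cost when tracking is off is one branch per iteration. The downside is that diagnostic code lives inside the core loop. The subclass approach keeps the base solver free of it.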