pytorch / ELF

ELF: a platform for game research with AlphaGoZero/AlphaZero reimplementation

Confusion about the updateEdgeStats() #166

Open neoql opened 4 years ago

neoql commented 4 years ago

In the updateEdgeStats() function, the edge reward is updated with edge.reward += reward, which is consistent with the formula in the paper "Mastering the game of Go without human knowledge".
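For reference, the backward pass described in that paper updates every edge (s, a) on the search path as N(s, a) ← N(s, a) + 1, W(s, a) ← W(s, a) + v and Q(s, a) ← W(s, a) / N(s, a). Below is a minimal Python sketch of that rule. The EdgeStats class and its fields are hypothetical and simply reuse the paper's symbols. This is not ELF's C++ updateEdgeStats() code.

```python
from dataclasses import dataclass


@dataclass
class EdgeStats:
    """Statistics for one edge (s, a), named after the paper's N, W and Q.

    Hypothetical illustration only; not ELF's actual data structure.
    """
    n: int = 0      # N(s, a): visit count
    w: float = 0.0  # W(s, a): total action value
    q: float = 0.0  # Q(s, a): mean action value

    def update(self, v: float) -> None:
        # Backup step as written in the paper: N += 1, W += v, Q = W / N.
        self.n += 1
        self.w += v
        self.q = self.w / self.n


edge = EdgeStats()
for v in (1.0, -1.0, 0.5):
    edge.update(v)
print(edge.n, edge.w, edge.q)
```

After these three updates, N = 3 and W = 0.5, so Q is their ratio (about 0.167).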

However, many other popular unofficial implementations, e.g. junxiaosong/AlphaZero_Gomoku and suragnair/alpha-zero-general, update the edge reward by adding v when the node belongs to the current player and -v when it belongs to the other player. These implementations achieve good results even though their update rule differs from the one described in the original paper.
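A simplified version of that sign-flipping update can be sketched as follows. Each edge stores values from the point of view of the player who chose it, so the sign of v flips at every ply on the way back to the root. The function name and the dict-based edge statistics are hypothetical and do not come from either repository.

```python
def backup_alternating(path_stats, leaf_value):
    """Back up one simulation, flipping the sign of the value at every ply.

    path_stats: hypothetical list of {"n": int, "w": float} dicts, one per
        edge on the search path, ordered from root to leaf.
    leaf_value: value of the leaf from the point of view of the player who
        chose the last edge on the path.
    """
    v = leaf_value
    for stats in reversed(path_stats):  # walk from the leaf back to the root
        stats["n"] += 1
        stats["w"] += v
        v = -v  # the edge one ply up was chosen by the opponent
    return path_stats


path = [{"n": 0, "w": 0.0} for _ in range(3)]
backup_alternating(path, 1.0)
print([s["w"] for s in path])  # [1.0, -1.0, 1.0]
```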

I think this approach is more intuitive than the one described in the original paper, because the Q value of each node then represents the value of that node for the corresponding player. My questions are:

1. I don't understand why, in the original paper, Q is the average of all v in the subtree, regardless of which player each v is a reward for.
2. Why are both methods effective, and what are the differences between them?
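One reading that reconciles the two: if a tree stores every value from a single fixed player's viewpoint, the selection step has to re-apply the sign of the player to move. If the tree flips the sign during backup instead, the stored Q already has the right sign and selection can simply maximise. The sketch below uses made-up playouts and hypothetical helper names. On a toy tree, it checks that the two conventions give Q values that differ only by the mover's sign and that they pick the same action at every node. It leaves out the exploration term U(s, a), which depends only on priors and visit counts and is therefore the same under both conventions.

```python
from collections import defaultdict

# Made-up playouts on a toy tree. Each playout lists the edges it traversed
# as (state, action, player_to_move) from root to leaf, plus the leaf value
# z measured from player +1's point of view. Players alternate: +1, -1, ...
PLAYOUTS = [
    ([("root", 0, +1), ("s0", 1, -1)], +1.0),
    ([("root", 0, +1), ("s0", 0, -1)], -1.0),
    ([("root", 1, +1), ("s1", 0, -1)], +0.5),
    ([("root", 1, +1), ("s1", 1, -1)], -0.2),
]

# Which player chooses the action on each edge.
mover = {(s, a): p for path, _ in PLAYOUTS for s, a, p in path}


def backup(playouts, per_mover):
    """Return Q(s, a) = W(s, a) / N(s, a) for every edge.

    per_mover=True:  W accumulates v from the viewpoint of the player at s
                     (the sign-flipping convention).
    per_mover=False: W always accumulates v from player +1's viewpoint.
    """
    n, w = defaultdict(int), defaultdict(float)
    for path, z in playouts:
        for state, action, player in path:
            n[(state, action)] += 1
            w[(state, action)] += player * z if per_mover else z
    return {edge: w[edge] / n[edge] for edge in n}


def best_action(q, state, player, per_mover, actions=(0, 1)):
    # A fixed-viewpoint tree must re-apply the mover's sign at selection time;
    # a per-mover tree already stores Q with the right sign.
    sign = 1 if per_mover else player
    return max(actions, key=lambda a: sign * q[(state, a)])


q_a = backup(PLAYOUTS, per_mover=True)
q_b = backup(PLAYOUTS, per_mover=False)

# The two tables differ only by the mover's sign on each edge...
assert all(q_a[e] == mover[e] * q_b[e] for e in q_a)
# ...and they pick the same action at every node.
for state, player in (("root", +1), ("s0", -1), ("s1", -1)):
    print(state, best_action(q_a, state, player, True),
          best_action(q_b, state, player, False))
```

This toy does not show what ELF itself does. It only shows that "average v from one fixed viewpoint" and "flip v per player" are interchangeable once selection accounts for whose turn it is.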

The following code is taken from suragnair/alpha-zero-general (search()) and junxiaosong/AlphaZero_Gomoku (update_recursive()), respectively.

import math

import numpy as np

EPS = 1e-8

def search(self, canonicalBoard):
        """
        This function performs one iteration of MCTS. It is recursively called
        till a leaf node is found. The action chosen at each node is one that
        has the maximum upper confidence bound as in the paper.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa are
        updated.

        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1,1] and if v is the value of a
        state for the current player, then its value is -v for the other player.

        Returns:
            v: the negative of the value of the current canonicalBoard
        """

        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids    # masking invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s    # renormalize
            else:
                # if all valid moves were masked make all valid moves equally probable

                # NB! All valid moves may be masked if either your NNet architecture is insufficient or you've get overfitting or something else.
                # If you have got dozens or hundreds of these messages you should pay attention to your NNet and/or training process.
                print("All valid moves were masked, do workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0
            return -v

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1

        # pick the action with the highest upper confidence bound
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)     # Q = 0 ?

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        # search() returns the negated value of next_s, so v is the value of
        # taking action a from the point of view of the player to move at s
        v = self.search(next_s)

        if (s, a) in self.Qsa:
            # incremental mean: Qsa <- (Nsa * Qsa + v) / (Nsa + 1)
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v
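The Qsa update in search() is an incremental mean, so apart from the sign convention it tracks the same quantity as the paper's Q = W / N. A small check follows. running_mean_update is a hypothetical helper that mirrors the two Qsa branches. Starting from Q = 0 and N = 0, its first call reproduces the "first visit" branch.

```python
import math


def running_mean_update(q, n, v):
    """Mirror the two Qsa branches in search(): the first visit sets Q = v and
    N = 1; later visits apply Q <- (N * Q + v) / (N + 1) and N <- N + 1."""
    return (n * q + v) / (n + 1), n + 1


values = [1.0, -1.0, 0.5, 0.25]
q, n = 0.0, 0
for v in values:
    q, n = running_mean_update(q, n, v)

# Same quantity as the paper's Q = W / N with W = sum(values), N = len(values).
print(n, q, sum(values) / len(values))
assert math.isclose(q, sum(values) / len(values))
```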

def update_recursive(self, leaf_value):
        """Like a call to update(), but applied recursively for all ancestors.
        """
        # If it is not root, this node's parent should be updated first.
        if self._parent: