AlphaZeroIncubator / AlphaZero

Our implementation of AlphaZero for simple games such as Tic-Tac-Toe and Connect4.

Remark on MCTS #33

Open homerours opened 4 years ago

homerours commented 4 years ago

Hi! I just went over mcts.py. Here are a few remarks:

The selection of the child with the highest PUCT score could be vectorized. To do so, we could store the prior, visit counts, and Q values (total_value) in arrays attached to the parent nodes:

class MCTSNode:
    children = [node1, node2, node3, node4]   # child MCTSNode objects
    policy = [p1, p2, p3, p4]                 # prior probability of each child
    n_visit = [n1, n2, n3, n4]                # visit count of each child
    total_values = [q1, q2, q3, q4]           # accumulated value of each child
...

Then,

index_chosen_child = np.argmax(self.total_values
                   + c * self.policy * np.sqrt(np.sum(self.n_visit)) / (1 + self.n_visit))

I do not know whether this would really speed things up, so it may not be urgent... A drawback is that each node would need to keep an index indicating its position in its parent's arrays (because we need to update total_values and n_visit when backpropagating).
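For illustration only, here is a minimal sketch of what that layout could look like (all names, such as index_in_parent and select_child, are hypothetical and not taken from mcts.py):

    import numpy as np

    class MCTSNode:
        def __init__(self, priors, parent=None, index_in_parent=None):
            self.parent = parent                    # None for the root
            self.index_in_parent = index_in_parent  # this node's slot in the parent's arrays
            self.children = [None] * len(priors)    # child MCTSNode objects, filled on expansion
            self.policy = np.asarray(priors, dtype=float)  # prior for each child
            self.n_visit = np.zeros(len(priors))           # visit count for each child
            self.total_values = np.zeros(len(priors))      # accumulated value for each child

        def select_child(self, c):
            # Vectorized PUCT over all children at once
            puct = (self.total_values / (1 + self.n_visit)
                    + c * self.policy * np.sqrt(self.n_visit.sum()) / (1 + self.n_visit))
            return int(np.argmax(puct))

    def backpropagate(node, value):
        # The stored index tells us which slot to update in the parent's arrays
        # (sign handling for alternating players is omitted here; see the later remark).
        while node.parent is not None:
            i = node.index_in_parent
            node.parent.n_visit[i] += 1
            node.parent.total_values[i] += value
            node = node.parent

The index_in_parent field is exactly the per-node index mentioned above as the drawback of this layout.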

PhilipEkfeldt commented 4 years ago

Thanks for the feedback!

  • The MCTSNode attributes prior and policy seem redundant: node.prior = node.parent._policy['child_index']

Good point, although we'd need to store the index in that case, or use the action tensor to get the index. Not sure if that improves performance.

  • The root_player attribute may lead to complications: after playing a move, you would need to update the root_player attribute of every node in the corresponding subtree. I do not believe we really need this attribute if we adopt a convention like 'value(position) = value for the player who has to play' (this is the convention used by the NN). Backpropagation would then become:
        current = leaf_node
        sign = 1                       # the leaf value is from the point of view of the player to move at the leaf
        while not current.is_root:
            current.n_visit += 1
            current.total_value += sign * v_leaf
            sign *= -1                 # flip perspective at each level up the tree
            current = current.parent

However, we need to be careful here when choosing the child with the highest PUCT: with the above convention, the Q-value of a child of the root corresponds to the point of view of the root's opponent. Hence we should probably use -Q when computing the PUCT.

Yup, I will fix this now that I'm working with the network implementation.
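For concreteness, a minimal sketch of selection under that sign convention, assuming per-node prior, n_visit and total_value attributes (the function name puct_select is hypothetical, not the repo's actual API):

    import numpy as np

    def puct_select(node, c):
        # Each child's total_value is accumulated from the child's own point of view,
        # i.e. the opponent of the player to move at `node`, so we negate Q here.
        q = np.array([child.total_value / max(child.n_visit, 1) for child in node.children])
        p = np.array([child.prior for child in node.children])
        n = np.array([child.n_visit for child in node.children])
        scores = -q + c * p * np.sqrt(n.sum()) / (1 + n)
        return node.children[int(np.argmax(scores))]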

  • (above we could use something like current.parent.n_visit)

You're right, they will always be the same.

To do so, we could store the prior, visit counts, and Q values (total_value) in arrays attached to the parent nodes:

class MCTSNode:
    children = [node1, node2, node3, node4]   # child MCTSNode objects
    policy = [p1, p2, p3, p4]                 # prior probability of each child
    n_visit = [n1, n2, n3, n4]                # visit count of each child
    total_values = [q1, q2, q3, q4]           # accumulated value of each child
...

Then,

index_chosen_child = np.argmax(self.total_values
                   + c * self.policy * np.sqrt(np.sum(self.n_visit)) / (1 + self.n_visit))

I do not know whether this would really speed things up, so it may not be urgent... A drawback is that each node would need to keep an index indicating its position in its parent's arrays (because we need to update total_values and n_visit when backpropagating).

This is definitely a good point, although I don't know if we need to change it now. As you say, it means we need to keep track of the index of the child when backpropagating.

PhilipEkfeldt commented 4 years ago

Just thought of something: for the vectorization you mention, we could use a dict on the parent keyed by action index, and store the action index (instead of, or in addition to, the action) on the node. That makes it easy to find the right entry, and we can still vectorize over dict.values() (assuming things stay in the same order, which I think holds for dicts as long as keys are inserted in the same order).
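A rough sketch of that idea (hypothetical names, not the current mcts.py), assuming children are stored in a dict keyed by action index and each child remembers its own key:

    import numpy as np

    class MCTSNode:
        def __init__(self, prior=0.0, parent=None, action_index=None):
            self.parent = parent
            self.action_index = action_index  # key of this node in parent.children
            self.prior = prior
            self.n_visit = 0
            self.total_value = 0.0
            self.children = {}                # {action_index: MCTSNode}

        def expand(self, priors):
            # priors: {action_index: prior probability} for the legal moves
            for a, p in priors.items():
                self.children[a] = MCTSNode(prior=p, parent=self, action_index=a)

        def visit_counts(self):
            # Plain dicts preserve insertion order (guaranteed by the language since Python 3.7),
            # so values() lines up the same way on every call.
            return np.fromiter((c.n_visit for c in self.children.values()), dtype=float)

Finding the right child is then an O(1) dict lookup, and arrays built from values() stay aligned with the keys as long as no entries are removed.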

guidopetri commented 4 years ago

You can use an OrderedDict if you need guaranteed ordering; plain dict insertion order is only part of the Python spec from 3.7 onward, so don't rely on it if we need to support older versions.
