kevaday / alphazero-general

A fast, generalized, and modified implementation of Deepmind's distinguished AlphaZero in PyTorch.
MIT License

AttributeError: 'function' object has no attribute 'supports_process' #8

Open albert88as opened 2 years ago

albert88as commented 2 years ago

Dear Kevi Aday, thanks for publishing your code. I get an error when I try to train the connect4 game. Can you help me solve it? Thank you in advance! Albert

PITTING AGAINST BASELINE: RawMCTSPlayer
Traceback (most recent call last):
  File "D:\ana3\envs\muzero\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "D:\ana3\envs\muzero\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "D:\88Projects\alphazero-general\alphazero\envs\connect4\train.py", line 58, in <module>
    c.learn()
  File "D:\88Projects\alphazero-general\alphazero\Coach.py", line 267, in learn
    self.compareToBaseline(self.model_iter)
  File "D:\88Projects\alphazero-general\alphazero\Coach.py", line 148, in wrapper
    ret = func(self, *args, **kwargs)
  File "D:\88Projects\alphazero-general\alphazero\Coach.py", line 589, in compareToBaseline
    self.arena = Arena(players, self.game_cls, use_batched_mcts=can_process, args=self.args)
  File "alphazero\Arena.pyx", line 66, in alphazero.Arena._set_state.decorator.wrapper
    ret = func(self, *args, **kwargs)
  File "alphazero\Arena.pyx", line 108, in alphazero.Arena.Arena.__init__
    self.players = players
  File "alphazero\Arena.pyx", line 129, in alphazero.Arena.Arena.players
    self.__check_players_valid()
  File "alphazero\Arena.pyx", line 132, in genexpr
    if self.use_batched_mcts and not all(p.player.supports_process() for p in self.players):
  File "alphazero\Arena.pyx", line 132, in genexpr
    if self.use_batched_mcts and not all(p.player.supports_process() for p in self.players):
AttributeError: 'function' object has no attribute 'supports_process'

vasilije2448 commented 2 years ago

Creating a wrapper class seems to solve the issue. I just found this repo so I'm not sure if it breaks anything else.

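class ProcessWrapper:
    """Wrap a network's process callable so it exposes supports_process()."""

    def __init__(self, process):
        # Stored under a different name; assigning to self.process would
        # shadow the process() method defined below.
        self._process = process

    def process(self, *args, **kwargs):
        # Forward the call to the wrapped callable.
        return self._process(*args, **kwargs)

    def supports_process(self):
        # Arena checks this before enabling batched MCTS.
        return True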

In alphazero/Coach.py, in compareToPast, use

nplayer = ProcessWrapper(self.train_net.process)
pplayer = ProcessWrapper(self.self_play_net.process)

instead of

nplayer = self.train_net.process
pplayer = self.self_play_net.process

and similarly in compareToBaseline.
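
For anyone who wants to see the failure in isolation, here is a minimal sketch that reproduces the AttributeError outside the repo and shows the wrapper passing the same kind of check. Everything except the supports_process name is an assumption for illustration: dummy_process and all_support_process are hypothetical stand-ins, not the repository's code, and the wrapper shown is a variant that keeps the wrapped callable under a separate attribute name.

```python
def dummy_process(batch):
    """Hypothetical stand-in for a network's process function."""
    # Return one placeholder value per input instead of a real evaluation.
    return [0.0 for _ in batch]


class ProcessWrapper:
    """Adds supports_process() to a plain callable."""

    def __init__(self, process):
        self._process = process

    def process(self, *args, **kwargs):
        # Forward to the wrapped callable.
        return self._process(*args, **kwargs)

    def supports_process(self):
        return True


def all_support_process(players):
    """Simplified version of the check that appears in the traceback."""
    return all(p.supports_process() for p in players)


# Passing the bare function fails because functions have no supports_process.
try:
    all_support_process([dummy_process])
    raised = False
    error_message = ""
except AttributeError as exc:
    raised = True
    error_message = str(exc)

# Wrapping the same function satisfies the check and still forwards calls.
wrapped = ProcessWrapper(dummy_process)
wrapped_ok = all_support_process([wrapped])
result = wrapped.process([1, 2, 3])
```

This only demonstrates why the check fails on a bare function. Whether the rest of Arena relies on other attributes of the player object is not something this sketch can confirm.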