AlphaZeroIncubator / AlphaZero

Our implementation of AlphaZero for simple games such as Tic-Tac-Toe and Connect4.

Performance improvement #41

Open homerours opened 4 years ago

homerours commented 4 years ago

I did a short self-play test on the server. I needed to modify a few things, so I created a branch 'self_play_speed_test'. I used cProfile for the profiling; here is the output of python -m cProfile -s cumulative speed_test.py (results sorted by cumulative time) for a Tic-Tac-Toe game of 9 moves and 100 MCTS runs per move:

         1689398 function calls (1602525 primitive calls) in 11.643 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    754/1    0.004    0.000   11.645   11.645 {built-in method builtins.exec}
        1    0.000    0.000   11.645   11.645 speed_test.py:1(<module>)
        1    0.000    0.000    8.992    8.992 mcts.py:363(self_play)
        9    0.001    0.000    8.990    0.999 mcts.py:245(mcts)
      900    0.010    0.000    8.987    0.010 mcts.py:319(forward)
      900    0.023    0.000    8.866    0.010 mcts.py:287(backpropagate)
      900    0.099    0.000    7.841    0.009 mcts.py:83(calc_policy_value)
81000/900    0.436    0.000    7.473    0.008 module.py:540(__call__)
      900    0.012    0.000    7.466    0.008 model.py:218(forward)
      900    0.017    0.000    6.375    0.007 model.py:84(forward)

The self-play took about 9 s, of which 7.8 s went to network evaluations (calc_policy_value), leaving about 1.2 s for the MCTS logic. It would be nice to check the results for Connect4, but I was unable to run it. Did you get similar results when running MCTS?

So, we need to improve the performance of the 'inference task' (and also the MCTS, if possible), for instance by using multiple processes.

homerours commented 4 years ago

Here is a proposal (let me know what you think of it!) for parallelizing the self-play.

To keep it simple, we could say that each 'CPU worker' has its own MCTS tree. This avoids having to share a tree among multiple processes, which would require putting locks everywhere.

The CPU workers will ask the GPU to evaluate board positions. The GPU will be handled by another process, which receives the boards through a Queue (the 'bucket'). Once the bucket contains enough board positions (100? 1000?), the GPU performs inference on that batch and sends the results back to the CPU workers (through another Queue).

However, we do not want the CPU workers to wait for the bucket to fill. So, after a CPU worker sends a request to the GPU, it continues to 'dive' into the tree toward another position (without waiting for the results from the GPU). To avoid ending up at the exact same position, we can use a 'virtual loss' here. The worker then sends this second position to the GPU for evaluation. Before diving a third time, it checks its result Queue to see whether the GPU has sent back the results of the previous positions: if there is something in the Queue, it updates the tree (also removing the virtual loss); if not, it keeps diving and looks for a third position to evaluate...

For an example of Processes communicating using Queues: https://github.com/AlphaZeroIncubator/AlphaZero/issues/24#issuecomment-644149207
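As a toy illustration of this layout (entirely hypothetical code, not from the repo), here is a sketch using threads and queue.Queue for compactness; the actual implementation would use multiprocessing.Process and multiprocessing.Queue, whose interface is analogous. The dummy evaluate_batch stands in for the batched network forward pass:

```python
import queue
import threading

BATCH_SIZE = 4  # how many positions the "GPU" accumulates before inference

def evaluate_batch(boards):
    # Stand-in for a batched network forward pass on the GPU.
    return [sum(b) for b in boards]

def gpu_server(request_q, result_qs, n_requests):
    served, batch = 0, []  # batch holds (worker_id, board) pairs
    while served < n_requests:
        batch.append(request_q.get())
        if len(batch) == BATCH_SIZE or served + len(batch) == n_requests:
            values = evaluate_batch([b for _, b in batch])
            for (worker_id, _), value in zip(batch, values):
                result_qs[worker_id].put(value)
            served += len(batch)
            batch = []

def cpu_worker(worker_id, request_q, result_q, results):
    # Each worker owns its own MCTS tree. It submits several positions
    # without blocking (this is where virtual loss would come in), then
    # collects the evaluations.
    boards = [[worker_id, 1], [worker_id, 2]]
    for board in boards:
        request_q.put((worker_id, board))
    for _ in boards:
        results.append((worker_id, result_q.get()))

n_workers = 2
request_q = queue.Queue()
result_qs = [queue.Queue() for _ in range(n_workers)]
results = []
server = threading.Thread(target=gpu_server,
                          args=(request_q, result_qs, n_workers * 2))
workers = [threading.Thread(target=cpu_worker,
                            args=(i, request_q, result_qs[i], results))
           for i in range(n_workers)]
server.start()
for w in workers:
    w.start()
for w in workers:
    w.join()
server.join()
print(sorted(results))  # → [(0, 1), (0, 2), (1, 2), (1, 3)]
```

Note that the workers submit all their positions before blocking on results; with blocking submitters and a large batch size, the server could otherwise deadlock waiting for a batch that never fills.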

homerours commented 4 years ago

To start, it would be simpler to implement a system where the CPU workers wait for the GPU, and then add the 'virtual loss' machinery once everything works.

homerours commented 4 years ago

In order to locate bottlenecks in the MCTS itself, I replaced the network evaluation with a random policy/value:

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    808/1    0.005    0.000    1.392    1.392 {built-in method builtins.exec}
        1    0.000    0.000    1.392    1.392 speed_test.py:1(<module>)
        1    0.000    0.000    0.789    0.789 mcts.py:366(self_play)
        9    0.001    0.000    0.787    0.087 mcts.py:248(mcts)
      900    0.010    0.000    0.781    0.001 mcts.py:322(forward)
       54    0.003    0.000    0.703    0.013 __init__.py:1(<module>)
      900    0.008    0.000    0.673    0.001 mcts.py:290(backpropagate)
        8    0.001    0.000    0.573    0.072 __init__.py:3(<module>)
    525/3    0.003    0.000    0.561    0.187 <frozen importlib._bootstrap>:986(_find_and_load)
    525/3    0.003    0.000    0.560    0.187 <frozen importlib._bootstrap>:956(_find_and_load_unlocked)
    499/3    0.003    0.000    0.560    0.187 <frozen importlib._bootstrap>:650(_load_unlocked)
    661/3    0.000    0.000    0.560    0.187 <frozen importlib._bootstrap>:211(_call_with_frames_removed)
    451/3    0.002    0.000    0.560    0.187 <frozen importlib._bootstrap_external>:777(exec_module)
  499/304    0.001    0.000    0.377    0.001 <frozen importlib._bootstrap>:549(module_from_spec)
    43/11    0.000    0.000    0.368    0.033 <frozen importlib._bootstrap_external>:1099(create_module)
    43/11    0.159    0.004    0.368    0.033 {built-in method _imp.create_dynamic}
      901    0.001    0.000    0.308    0.000 _game.py:268(is_game_over)
      902    0.005    0.000    0.307    0.000 _game.py:280(result)
      902    0.021    0.000    0.294    0.000 _game.py:239(get_game_status)
   388/24    0.001    0.000    0.278    0.012 {built-in method builtins.__import__}
      349    0.049    0.000    0.229    0.001 mcts.py:125(expand)
  488/249    0.002    0.000    0.213    0.001 <frozen importlib._bootstrap>:1017(_handle_fromlist)
    19729    0.010    0.000    0.203    0.000 tensor.py:25(wrapped)
    14266    0.149    0.000    0.149    0.000 {method 'eq' of 'torch._C._TensorBase' objects}
      900    0.027    0.000    0.129    0.000 mcts.py:83(calc_policy_value)
     1463    0.043    0.000    0.107    0.000 _game.py:179(board_after_move)
      451    0.017    0.000    0.104    0.000 <frozen importlib._bootstrap_external>:849(get_code)
     6913    0.029    0.000    0.074    0.000 {built-in method builtins.any}

While I do not know what the "frozen importlib" lines refer to, it seems that the bottleneck is checking whether the game is over (which is quite intuitive).

guidopetri commented 4 years ago

I think "frozen importlib" refers to trying to import a library that was already previously imported. e.g.

for _ in range(10000):
    import pandas

would trigger that.


homerours commented 4 years ago

oh... does this mean that some libraries were imported ~500 times?

guidopetri commented 4 years ago

I think so, but I'm not sure about the above - I'm talking from memory, so I'd have to look it up to make sure.

But if so, then the fix is to move the imports to somewhere where they only execute once.


PhilipEkfeldt commented 4 years ago

Thanks Leo, this is great. Regarding the imports, I'm not sure where/how they get run multiple times. I checked all the files and all imports are at the top of the files, outside of function definitions. As a test, we could try running a file that only has the imports and nothing else.

abhon commented 4 years ago

Thanks for running the tests Leo!

It would be nice to check the results for Connect4, but I was unable to run it. Did you get similar results when running MCTS?

I haven't gotten around to finishing Connect4 yet; I think all that's left is for me to write a test file for the class to make sure it works. As we discussed at our last meeting, I might leave the slightly less efficient solution in place for now, seeing as a standard Connect4 board isn't that large; if it proves to be a bottleneck, I can write a more efficient solution. Will try to get to work on it soon...

PhilipEkfeldt commented 4 years ago

To add, I will start looking into implementing your suggestion for parallelization.

homerours commented 4 years ago

Trying with a 'handcrafted' Tic-Tac-Toe result checker (https://github.com/AlphaZeroIncubator/AlphaZero/blob/85a3a130b8a953d73578c3a54d3b0011475f0f5d/alphazero/_game.py#L285-L307) brings the total time from ~0.3 s down to ~0.07 s (about 5x faster). I will think about ways of making the Connect4 checker faster (one option could be to use Cython for this particular function...)
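For reference, a handcrafted checker of this kind can look like the following (an illustrative re-creation, not the committed code): instead of generic tensor operations, it enumerates the 8 winning lines of the 3x3 board directly.

```python
# The 8 winning lines of a 3x3 board, as flat indices.
LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
    (0, 4, 8), (2, 4, 6),             # diagonals
]

def winner(board):
    """board is a flat list of 9 cells: 0 / 1 for the players, -1 for empty.
    Returns the winning player, or None if there is no winner."""
    for a, b, c in LINES:
        if board[a] != -1 and board[a] == board[b] == board[c]:
            return board[a]
    return None
```

Eight tuple lookups per call is far cheaper than building tensors, which is where the 5x comes from.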

homerours commented 4 years ago

I haven't gotten around to finishing Connect4 yet; I think all that's left is for me to write a test file for the class to make sure it works. As we discussed at our last meeting, I might leave the slightly less efficient solution in place for now, seeing as a standard Connect4 board isn't that large; if it proves to be a bottleneck, I can write a more efficient solution. Will try to get to work on it soon...

I added a function winning_move to the Connect4 class that checks whether the last move is winning: 577f287 (I took the liberty of committing twice on the Connect4 branch, I hope that's OK with you).

I also had a concern about the get_legal_moves function: https://github.com/AlphaZeroIncubator/AlphaZero/blob/577f287aeee0f15c4d6e6d64daf7673ac3b75d1b/alphazero/_game.py#L465

Shouldn't this function return a list of integers corresponding to the columns that are not full? That is, do something like:

return torch.where(board[0, :] == -1)[0]

I had a similar concern about the dimensions of the PolicyHead of the network, which might not necessarily match the dimensions of the board: https://github.com/AlphaZeroIncubator/AlphaZero/issues/37#issuecomment-655551837

homerours commented 4 years ago

I did some tests with Cython for the function that checks whether a given move is winning (at Connect4) here: https://github.com/AlphaZeroIncubator/AlphaZero/tree/self_play_speed_test/alphazero/speed_test/demo_Cython

Result: using Cython gives a 15x to 20x speedup :smile_cat: (and I am new to Cython, so one could probably improve on that). Hence, if the bottleneck is still the game result checking, we could compile some critical functions with Cython (which is very easy to do).
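For concreteness, here is a pure-Python sketch (my own, not the code from 577f287) of such a last-move check; it is exactly the kind of tight integer loop that Cython compiles well once the indices get cdef types.

```python
def winning_move(board, row, col):
    """board: list of rows, each a list of columns; -1 marks an empty cell.
    Returns True if the piece at (row, col) completes a line of 4."""
    player = board[row][col]
    rows, cols = len(board), len(board[0])
    # The four line directions: horizontal, vertical, both diagonals.
    for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
        count = 1  # the piece just played
        for sign in (1, -1):  # walk away from (row, col) in both senses
            r, c = row + sign * dr, col + sign * dc
            while 0 <= r < rows and 0 <= c < cols and board[r][c] == player:
                count += 1
                r, c = r + sign * dr, c + sign * dc
        if count >= 4:
            return True
    return False
```

Checking only the lines through the last move is O(1) per move, unlike rescanning the whole board.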

guidopetri commented 4 years ago

That's amazing. Thanks so much Leo :)


abhon commented 4 years ago

Thank you so much Leo! I've never heard of Cython, so it's something new to look into.

abhon commented 4 years ago

I also had a concern about the get_legal_moves function: https://github.com/AlphaZeroIncubator/AlphaZero/blob/577f287aeee0f15c4d6e6d64daf7673ac3b75d1b/alphazero/_game.py#L465

Shouldn't this function return a list of integers corresponding to the columns that are not full? That is, do something like:

return torch.where(board[0, :] == -1)[0]

I had a similar concern with the dimension of the PolicyHead of the network, that might not necessarily be of the dimensions of the board: #37 (comment)

Fixed this, but rather than returning the torch.where(), I'm thinking just:

return board[0, :] == -1

since in the TTT implementation Sid wrote, we return a boolean tensor of the legal moves.