Open homerours opened 4 years ago
Here is a proposal (let me know what you think of it!) for parallelizing the self-play.
To keep it simple, we could say that each 'CPU worker' has its own MCTS tree. This avoids having to share a tree among multiple processes, which would require putting 'locks' everywhere.
The CPU workers will ask the GPU to evaluate some board positions. The GPU will be handled by another process, which receives the boards through a Queue (= the bucket). Once the bucket contains enough board positions (100? 1000?), the GPU performs inference on this batch and then sends the results back to the CPU workers (through another Queue).
However, we do not want the CPU workers to wait for the bucket to be filled. So, after a CPU worker sends a request to the GPU, it continues to 'dive' into the tree for another position (without waiting for the results from the GPU). To avoid ending up at the exact same position, we can use 'virtual loss' here. The CPU worker therefore sends another position to evaluate to the GPU. Before it dives for a third time, it checks its Queue to see if the GPU has sent back the results of the previous positions: if there is something in the Queue, it updates the tree (also removing the virtual loss); if not, it continues to 'dive' and looks for a third position to evaluate...
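To illustrate the 'virtual loss' bookkeeping, here is a minimal sketch with a hypothetical `Node` class (not the repo's actual MCTS code): while an evaluation is in flight, the node is temporarily treated as a visited loss, which steers concurrent dives toward other branches; once the real value arrives, the fake loss is replaced.

```python
# Sketch of 'virtual loss' bookkeeping (hypothetical Node class, not the
# repo's actual MCTS implementation). While an evaluation is in flight,
# the node temporarily looks worse, so other dives pick different branches.

VIRTUAL_LOSS = 1.0

class Node:
    def __init__(self):
        self.visit_count = 0
        self.value_sum = 0.0

    def add_virtual_loss(self):
        # Pretend we visited and lost: discourages re-selecting this node
        self.visit_count += 1
        self.value_sum -= VIRTUAL_LOSS

    def revert_virtual_loss(self, value):
        # Replace the fake loss with the real evaluation result
        # (the visit added above now counts as the real visit)
        self.value_sum += VIRTUAL_LOSS + value

    def q(self):
        # Mean value used by the selection rule
        return self.value_sum / self.visit_count if self.visit_count else 0.0
```

In a real tree the virtual loss would be applied to every node along the selected path during backpropagation, not to a single node.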
For an example of Processes communicating using Queues: https://github.com/AlphaZeroIncubator/AlphaZero/issues/24#issuecomment-644149207
To start, it would be simpler to implement a system where the CPU workers wait for the GPU, and then add the 'virtual loss stuff' if everything works.
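A minimal sketch of this simpler blocking variant, using threads and `queue.Queue` as stand-ins for processes and `multiprocessing.Queue`, and a stub evaluator in place of the network (all names here are illustrative):

```python
# Sketch of the simple (blocking) variant: workers push positions to a
# shared request queue and block until the evaluator replies. With real
# processes the same structure would use multiprocessing.Queue.
import queue
import threading

BATCH_SIZE = 4

def evaluator(requests, n_total):
    # Collect a batch of requests, "evaluate" it, then reply to each
    # worker through that worker's own reply queue.
    served = 0
    while served < n_total:
        batch = [requests.get()]
        while len(batch) < BATCH_SIZE and served + len(batch) < n_total:
            batch.append(requests.get())
        for board, reply_queue in batch:
            policy_value = (0.5, 0.0)  # stub: real code runs the network here
            reply_queue.put(policy_value)
        served += len(batch)

def worker(requests, board, results):
    reply_queue = queue.Queue()
    requests.put((board, reply_queue))
    results.append(reply_queue.get())  # block until the evaluator answers

requests = queue.Queue()
results = []
threads = [threading.Thread(target=worker, args=(requests, b, results))
           for b in range(8)]
server = threading.Thread(target=evaluator, args=(requests, 8))
server.start()
for t in threads:
    t.start()
for t in threads:
    t.join()
server.join()
print(len(results))  # 8
```

The per-worker reply queue is what lets each result go back to the worker that asked for it; the virtual-loss version would only change the worker loop (keep diving instead of blocking on `reply_queue.get()`).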
In order to locate bottlenecks in the MCTS, I replaced the network evaluation by a random policy/value:
```
Ordered by: cumulative time

ncalls tottime percall cumtime percall filename:lineno(function)
808/1 0.005 0.000 1.392 1.392 {built-in method builtins.exec}
1 0.000 0.000 1.392 1.392 speed_test.py:1(<module>)
1 0.000 0.000 0.789 0.789 mcts.py:366(self_play)
9 0.001 0.000 0.787 0.087 mcts.py:248(mcts)
900 0.010 0.000 0.781 0.001 mcts.py:322(forward)
54 0.003 0.000 0.703 0.013 __init__.py:1(<module>)
900 0.008 0.000 0.673 0.001 mcts.py:290(backpropagate)
8 0.001 0.000 0.573 0.072 __init__.py:3(<module>)
525/3 0.003 0.000 0.561 0.187 <frozen importlib._bootstrap>:986(_find_and_load)
525/3 0.003 0.000 0.560 0.187 <frozen importlib._bootstrap>:956(_find_and_load_unlocked)
499/3 0.003 0.000 0.560 0.187 <frozen importlib._bootstrap>:650(_load_unlocked)
661/3 0.000 0.000 0.560 0.187 <frozen importlib._bootstrap>:211(_call_with_frames_removed)
451/3 0.002 0.000 0.560 0.187 <frozen importlib._bootstrap_external>:777(exec_module)
499/304 0.001 0.000 0.377 0.001 <frozen importlib._bootstrap>:549(module_from_spec)
43/11 0.000 0.000 0.368 0.033 <frozen importlib._bootstrap_external>:1099(create_module)
43/11 0.159 0.004 0.368 0.033 {built-in method _imp.create_dynamic}
901 0.001 0.000 0.308 0.000 _game.py:268(is_game_over)
902 0.005 0.000 0.307 0.000 _game.py:280(result)
902 0.021 0.000 0.294 0.000 _game.py:239(get_game_status)
388/24 0.001 0.000 0.278 0.012 {built-in method builtins.__import__}
349 0.049 0.000 0.229 0.001 mcts.py:125(expand)
488/249 0.002 0.000 0.213 0.001 <frozen importlib._bootstrap>:1017(_handle_fromlist)
19729 0.010 0.000 0.203 0.000 tensor.py:25(wrapped)
14266 0.149 0.000 0.149 0.000 {method 'eq' of 'torch._C._TensorBase' objects}
900 0.027 0.000 0.129 0.000 mcts.py:83(calc_policy_value)
1463 0.043 0.000 0.107 0.000 _game.py:179(board_after_move)
451 0.017 0.000 0.104 0.000 <frozen importlib._bootstrap_external>:849(get_code)
6913 0.029 0.000 0.074 0.000 {built-in method builtins.any}
```
While I do not know what the "frozen importlib" lines refer to, it seems that the bottleneck is checking whether the game is over or not (which is quite intuitive).
I think "frozen importlib" refers to trying to import a library that was already previously imported, e.g.

```python
for _ in range(10000):
    import pandas
```

would trigger that.
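For what it's worth, one way to sanity-check this hypothesis: Python caches loaded modules in `sys.modules`, so a repeated `import` should be a cheap lookup rather than a re-execution of the module body. That would suggest the ~500 `_find_and_load` calls come instead from a large package (e.g. torch) importing its many submodules once each at startup.

```python
# Re-importing a cached module is nearly free: Python keeps loaded modules
# in sys.modules, and a repeated `import` is essentially a dict lookup plus
# a name binding, so the module body does not run again.
import sys
import time

import json  # first import: executes the module body once

start = time.perf_counter()
for _ in range(10000):
    import json  # cached: no re-execution
elapsed = time.perf_counter() - start

print('json' in sys.modules)  # True
print(elapsed < 1.0)          # 10k cached imports take well under a second
```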
Oh... does this mean that some libraries were imported ~500 times?
I think so, though I'm not sure about the above; I'm talking from memory, so I'd have to look it up to make sure.
But if so, then the fix is to move the imports to somewhere where they only execute once.
Thanks Leo, this is great. Regarding the imports, I'm not sure where/how they get run multiple times. I checked all the files, and all imports are at the start of the files, outside of function definitions. As a test, we could try running a file that only has the imports and nothing else.
Thanks for running the tests Leo!
It would be nice to check the results for Connect4, but I was unable to run it. Did you get similar results when running MCTS?
I haven't gotten to finishing Connect4 yet; I think all that's left is for me to write a test file for the class to make sure it works. As we discussed at our last meeting, I might leave the slightly less efficient solution there for the meantime, seeing as a standard Connect4 board isn't that large; if it proves to be a bottleneck, I can write a more efficient solution. I will try to get to work on it soon...
To add, I will start looking into implementing your suggestion for parallelization.
Trying with a 'handcrafted' TicTacToe result checker (https://github.com/AlphaZeroIncubator/AlphaZero/blob/85a3a130b8a953d73578c3a54d3b0011475f0f5d/alphazero/_game.py#L285-L307) makes the total time go from ~0.3 s to ~0.07 s (about 5 times faster). I will try to think of ways to make the Connect4 one faster (an option could be to use Cython for this particular function...).
I added a function `winning_move` to the Connect4 class that checks if the last move is winning: 577f287
(I took the liberty of committing twice on the Connect4 branch, I hope that's OK with you.)
I also had a concern about the `get_legal_moves` function: https://github.com/AlphaZeroIncubator/AlphaZero/blob/577f287aeee0f15c4d6e6d64daf7673ac3b75d1b/alphazero/_game.py#L465
Shouldn't this function return a list of integers corresponding to the columns that are not full? That is, do something like:

```python
return torch.where(board[0, :] == -1)[0]
```
I had a similar concern with the dimension of the PolicyHead of the network, that might not necessarily be of the dimensions of the board: https://github.com/AlphaZeroIncubator/AlphaZero/issues/37#issuecomment-655551837
I did some tests with Cython for the function that checks if a given move is winning (at Connect4) here.
Result: using Cython leads to a 15x to 20x speedup :smile_cat: (and I am new to Cython, so one can probably improve on that). Hence, if the bottleneck is still the game result checking, we could compile some critical functions with Cython (which is very easy to do).
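For illustration, a typed Cython version of such a win check might look like the sketch below (hypothetical names and board encoding, not the repo's actual code; it needs compiling, e.g. with `cythonize -i win_check.pyx`, and the memoryview expects e.g. a numpy array of dtype `np.intc`):

```cython
# win_check.pyx -- hypothetical sketch, not the repo's actual API.
# Typing the loop variables lets Cython compile the hot loop to plain C.
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def winning_move(int[:, :] board, int row, int col):
    cdef int player = board[row, col]
    cdef int dr, dc, sign, r, c, count
    for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
        count = 1
        for sign in (1, -1):
            r = row + sign * dr
            c = col + sign * dc
            while 0 <= r < board.shape[0] and 0 <= c < board.shape[1] and board[r, c] == player:
                count += 1
                r += sign * dr
                c += sign * dc
        if count >= 4:
            return True
    return False
```

The speedup comes mostly from removing Python-object overhead in the inner loop, which is why a small, self-contained function like this is a good Cython candidate.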
That's amazing. Thanks so much Leo :)
Thank you so much Leo! I've never heard of Cython, so that's something new to look into.
Fixed this for `get_legal_moves`, but rather than returning the `torch.where()` result, I'm thinking just:
```python
return board[0, :] == -1
```
since in the TTT Sid wrote, we are returning a boolean tensor of the legal moves.
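A tiny illustration of the two conventions, using plain Python lists in place of tensors (with torch, the index form would be `torch.where(board[0, :] == -1)[0]` and the mask form `board[0, :] == -1`):

```python
# Two ways to report legal Connect4 moves from the top row of the board
# (plain lists stand in for torch tensors; -1 marks an empty slot).
top_row = [1, -1, 0, -1]  # columns 1 and 3 are not full

# Index form: which columns can be played
legal_indices = [c for c, v in enumerate(top_row) if v == -1]
# Mask form: one boolean per column, matching the policy head's output shape
legal_mask = [v == -1 for v in top_row]

print(legal_indices)  # [1, 3]
print(legal_mask)     # [False, True, False, True]
```

The mask form is convenient for masking the policy vector elementwise, while the index form is convenient for sampling a move.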
I did a short test of self-play on the server. I needed to modify a few things, so I created a branch 'self_play_speed_test'. I used cProfile for the profiling; here is the output of

```shell
python -m cProfile -s cumulative speed_test.py
```

(the results are sorted by cumulative time) for a TicTacToe game of 9 moves and 100 MCTS runs per move. The self-play took about 9 s, and 7.8 s are used by the network evaluations (`calc_policy_value`), so the MCTS logic takes the remaining 1.2 s. It would be nice to check the results for Connect4, but I was unable to run it. Did you get similar results when running MCTS? So, we need to improve the performance of the 'inference task' (and also the MCTS if possible), using for instance multiple processes.