AlphaZeroIncubator / AlphaZero

Our implementation of AlphaZero for simple games such as Tic-Tac-Toe and Connect4.

MCTS improvements #24

Open PhilipEkfeldt opened 4 years ago

PhilipEkfeldt commented 4 years ago
homerours commented 4 years ago

I discussed the "batch inference server" a bit with a friend, who suggested that instead of having CPU workers explore the tree and send requests to a GPU inference service, it would probably be easier to have a "master GPU thread" that asks the CPU workers for new positions to evaluate. Something like:

net = NN()
net = net.half().to('cuda')  # store the network on the GPU, in half precision
pool = ThreadPool(nb_thread)  # create a pool of threads
thread_ids = list(range(nb_thread))

while not games_ended:  # pseudocode: loop until all games are finished
    results = pool.map(query_MCTS, thread_ids)  # run query_MCTS in parallel on each thread of the pool
    input0 = torch.cat(results, dim=0)  # concatenate the results into a single batch
    input0 = input0.half().to('cuda')   # convert to half precision and send to the GPU
    output0 = net(input0)               # process the batch
    # then, send the outputs back to the MCTS trees
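(query_MCTS is not defined above; here is a minimal sketch of what each thread might return. The per-thread tree API used here, trees / select_leaf / encoded_board / pending, is an assumption for illustration, not actual repo code.)

import torch

def query_MCTS(thread_id, trees=None, pending=None, width=7, height=8):
    # Hypothetical sketch: each thread descends its own tree to a leaf that needs
    # a network evaluation and returns the encoded position for batching.
    if trees is None:
        # stand-in so the sketch runs on its own: return a random encoded position
        return torch.randn(1, 2, width, height)
    leaf = trees[thread_id].select_leaf()    # assumed selection step down to an unevaluated leaf
    pending[thread_id] = leaf                # remember the leaf so the NN output can be backed up later
    return leaf.encoded_board.unsqueeze(0)   # encoded position, shape (1, 2, width, height)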
homerours commented 4 years ago

In fact, this seems to lead to a lot of overhead: launching processes takes some time, and the CPU workers are not "filling the board bucket" while the GPU is working. It would probably be more efficient to have something completely asynchronous: an "inference server" (an autonomous Python process) fed by a Queue. I'll do some tests!

PhilipEkfeldt commented 4 years ago

Thanks Leo! Yeah, I think the first suggestion would be kind of complicated as well since we'd need to return the state of the tree every time so we can continue it after we get the inference result?
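(One way around returning the whole tree could be to ship only a small request object that carries enough context to route the result back, while the worker keeps its tree in memory. A rough sketch; the EvalRequest name and fields are my own, not something in the repo:

from dataclasses import dataclass
import torch

@dataclass
class EvalRequest:
    # hypothetical request object sent through the queue instead of a raw tensor
    tree_id: int            # which game/tree the position belongs to
    node_id: int            # which leaf is waiting for priors and a value
    position: torch.Tensor  # encoded board, e.g. shape (2, 7, 8)

The inference server would then send back (tree_id, node_id, policy, value) so the worker can resume the search at the right node.)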

homerours commented 4 years ago

Here is the test I did (sorry for the dirty copy/paste, I did not want to interfere with the git repo):

import torch
import time
import random
from resnet import Net
import torch.multiprocessing as mp

def CPU_worker(q, nb_pos, delta):
    # A CPU worker puts nb_pos board positions
    # in the queue q, roughly every delta seconds
    game_play = {'width': 7, 'height': 8}
    i = 0
    while True:
        i += 1
        time.sleep(delta * random.random())
        s = torch.randn(nb_pos, 2, game_play['width'], game_play['height'])
        q.put(s)

def merger(q1, q2, group_size):
    # The role of the merger is to merge the requests of the
    # CPU workers into a single tensor.
    # It could also be used to check whether a request has already been processed.
    i = 1
    while True:
        s = 0
        batch = []
        while s < group_size:
            r = q1.get()
            batch.append(r)
            s += 1

        q2.put(torch.cat(batch, dim=0))
        if (i % 10 == 0 and q1.qsize() > group_size):
            print('First queue is full: merger does not merge fast enough!')
        i += 1

def GPU_service(q, ev, nbEpoch):
    # GPU inference service:
    # processes nbEpoch batches of tensors through the ResNet
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    game_play = {'width': 7, 'height': 8}  # ~Connect4 dims
    net = Net(2, 64, game_play, 10)  # 64 filters, 10 residual blocks
    print('sending NN to GPU...')
    net.half().to(device)
    print('SENT!')
    ev.set()  # let the main process know that the network has been loaded on the GPU,
    # so that it can start the CPU workers
    i = 0
    while (i < nbEpoch):
        if (i % 10 == 9):
            print('EPOCH ' + str(i + 1))
            print(str(q.qsize()) + ' batches in the GPU queue.')
        input0 = q.get()  # get the next batch
        input0 = input0.half().to(device)

        # DO INFERENCE
        output0 = net(input0)
        torch.cuda.synchronize()  # CUDA is asynchronous, so wait for the computation to finish
        i += 1
    print('END cuda')

if __name__ == '__main__':
    mp.set_start_method('spawn')

    queueCPU = mp.Queue()  # queue fed by the CPU workers
    queueGPU = mp.Queue()  # queue read by the GPU
    ev = mp.Event()
    nbEpoch = 100  # number of batches processed by the GPU
    # Each CPU worker requests nb_pos position evaluations roughly every delta_time seconds:
    nb_pos = 100
    delta_time = 0.01
    # These nb_pos-position requests are merged into larger groups by the 'merger':
    batch_size = 100
    # hence, the batches sent to the GPU contain batch_size*nb_pos positions

    # total number of positions evaluated by the GPU:
    total_requests = nbEpoch * batch_size * nb_pos

    r1 = mp.Process(target=GPU_service, args=(queueGPU, ev, nbEpoch))
    g1 = mp.Process(target=merger, args=(queueCPU, queueGPU, batch_size))
    print('Launching GPU process...')
    r1.start()
    g1.start()  # launching the merger
    ev.wait()
    print('GPU process launched!')

    nb_process = 6
    processes = [mp.Process(target=CPU_worker, args=(queueCPU, nb_pos, delta_time)) for _ in range(nb_process)]
    for i, p in enumerate(processes, start=1):
        p.start()
        print('Launched CPU worker ' + str(i))

    u1 = time.time()
    r1.join()
    u2 = time.time()

    g1.terminate()
    for p in processes:
        p.terminate()

    print('FINISHED')

    time_per_pos = (u2 - u1) / total_requests
    time_per_cycle = 120000000 * time_per_pos / 3600  # ~120 million evaluations per generation cycle, in hours
    print('Time/position:', time_per_pos)
    print('Hours/generation cycle:', time_per_cycle)
homerours commented 4 years ago

Here, 6 CPU workers send positions to a queue. The positions in the queue are then concatenated into a big tensor by a 'merger' process, and that tensor is read by the GPU service.

This saturates the P40 GPU, which is able to process one position in about ~2 * 10^-5 s (depending on the depth of the ResNet), which should bring a game generation cycle (~120 million evaluations) down to about 1 hour.
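As a quick back-of-envelope check of that estimate, using the figures above:

time_per_pos = 2e-5                # seconds per position measured on the P40
positions_per_cycle = 120_000_000  # ~120 million evaluations per generation cycle
print(positions_per_cycle * time_per_pos / 3600)  # ~0.67 hours, i.e. about 1 hour with some overhead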

homerours commented 4 years ago

> Thanks Leo! Yeah, I think the first suggestion would be kind of complicated as well since we'd need to return the state of the tree every time so we can continue it after we get the inference result?

Yes. But the CPU workers should probably not wait for the GPU evaluation before continuing to explore the tree: we could, for instance, use the 'virtual loss' approach to do that.
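(For reference, a toy illustration of the virtual-loss idea; this Node class is made up for the example and is not the repo's MCTS node. While a leaf's evaluation is pending on the GPU, a temporary loss is added along the selected path so other workers are steered toward different parts of the tree, and the real result replaces it when it comes back.)

class Node:
    # toy node for illustration only
    def __init__(self):
        self.N = 0        # completed visits
        self.W = 0.0      # accumulated value
        self.virtual = 0  # pending (in-flight) visits

    def Q(self):
        # pending evaluations are counted as losses, which lowers Q and
        # discourages other workers from selecting the same path
        n = self.N + self.virtual
        return (self.W - self.virtual) / n if n > 0 else 0.0

def add_virtual_loss(path):
    for node in path:
        node.virtual += 1

def backup(path, value):
    # called once the GPU evaluation is available
    for node in path:
        node.virtual -= 1  # remove the temporary loss
        node.N += 1
        node.W += value
        value = -value     # assumed two-player sign convention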

PhilipEkfeldt commented 4 years ago

This looks great Leo! I can't say I understand all of the code, but this looks to be exactly what we need. Hopefully I will be done with the non-parallel mcts implementation soon and we can start testing this. I will look into virtual loss as well.

homerours commented 4 years ago

Sure, the priority is definitely to make the non-parallel mcts work. This should also be a good benchmark to compare with parallel mcts.