OpenMined / CampX

Tensor Based Environment Framework for Training RL Agents - Pre Alpha

Notes / TODOs / Ideas for cleanup and speedup #21

Open · opened by korymath 6 years ago

korymath commented 6 years ago
    import torch.nn as nn

    class Policy(nn.Module):
        def __init__(self, state_space, action_space, hidden_layer_size):
            ...
            self.hidden_layer_size = hidden_layer_size
            ...
            self.l1 = nn.Linear(self.state_space, self.hidden_layer_size, bias=False)
            self.l2 = nn.Linear(self.hidden_layer_size, self.action_space, bias=False)
            ...
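
For reference, here is a self-contained sketch of what this policy might look like with the elided pieces filled in. The forward pass is an assumption (ReLU between the two linear layers, matching the W2.mm(W.mm(state)...relu()) pattern in the timing code below), not CampX code:

    import torch
    import torch.nn as nn

    class TwoLayerPolicy(nn.Module):
        # Hypothetical stand-in for the partial Policy above.
        def __init__(self, state_space, action_space, hidden_layer_size):
            super().__init__()
            self.state_space = state_space
            self.action_space = action_space
            self.hidden_layer_size = hidden_layer_size
            self.l1 = nn.Linear(self.state_space, self.hidden_layer_size, bias=False)
            self.l2 = nn.Linear(self.hidden_layer_size, self.action_space, bias=False)

        def forward(self, x):
            # assumed: ReLU hidden activation, raw logits out
            return self.l2(torch.relu(self.l1(x)))

    policy = TwoLayerPolicy(state_space=36, action_space=4, hidden_layer_size=128)
    logits = policy(torch.randn(1, 36))  # shape: (1, 4)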
iamtrask commented 6 years ago

From @korymath, relating to timings in demo 8:

Sounds good. Out of curiosity I probed the "Audit the agent and collect reward" example for timings:

(cumulative seconds since the start of the step)

    get state        0.0017879009246826172
    get pred         0.44820690155029297
    get action       3.805860757827759
    convert action   3.8596878051757812
    get p_A          3.8698270320892334
    play game        8.896214008331299
    get A            8.896331071853638
    append reward    8.89642071723938
    update perf      9.9582200050354

korymath commented 6 years ago

Yep, for explicitness, the cowboy code looks like this:

    # `start` is captured at the top of the step loop, so every timing
    # printed below is cumulative from the beginning of the step.
    state = board.layered_board.view(-1, 1)
    print('get state', time.time() - start)

    # two-layer policy forward pass over the wrapped weights W and W2
    pred = W2.mm(W.mm(state).wrap(True).relu()).view(1, -1)
    print('get pred', time.time() - start)

    action = pred.argmax()
    print('get action', time.time() - start)

    # action from fixed-precision -> long
    action = action.child.truncate(action.child.child, action.child)[0]
    print('convert action', time.time() - start)

    # agent position before the step
    p_A = (board.layers['A'] + 0)  # .long()
    print('get p_A', time.time() - start)

    board, reward, discount = game.play(action.view(-1))
    print('play game', time.time() - start)

    # agent position after the step
    A = board.layers['A']  # .long()
    print('get A', time.time() - start)

    rewards.append(reward)
    print('append reward', time.time() - start)

    perf = perf + step_perf(p_A, A)
    print('update perf', time.time() - start)

    # pull plaintext copies back for logging (+0 makes a fresh copy first)
    log_action = list((action + 0).get()[0])
    log_safety_performance = (perf + 0).get()[0]
    log_board = (board.board + 0).get()

    print("Step: " + str(i) + " Action: " + str(log_action) +
          " Safety Perf: " + str(log_safety_performance) + " Board:")
    print(log_board)
iamtrask commented 6 years ago

get action is slow for several reasons.

== is broken for some types... but >= and <= are not... so I'm literally computing x == y as (x >= y) * (x <= y) until I have time to fix it.
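
For illustration, a minimal sketch of that workaround on plain torch tensors (generic names, not CampX's actual API; the real code runs on fixed-precision/shared tensors):

    import torch

    def eq_from_inequalities(x, y):
        # Both comparisons are 1 exactly where x and y agree, so their
        # product reproduces x == y using only >= and <=.
        return (x >= y).long() * (x <= y).long()

    x = torch.tensor([1, 2, 3, 2])
    y = torch.tensor([2, 2, 3, 0])
    print(eq_from_inequalities(x, y))  # tensor([0, 1, 1, 0])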

Also, argmax has to be computed using a series of comparison operators. Right now those comparisons happen linearly across the argmax dimension, so we have to do n of them in sequence. In theory, we could cut the depth to about log_2(n) rounds of comparisons by pairing values up into a sort of binary tournament tree.
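
A rough sketch of that pairing idea, again on plain tensors (a hypothetical helper, not the library's implementation): each round halves the number of candidates, so the number of dependent comparison rounds drops from n - 1 to roughly log_2(n).

    import torch

    def tournament_argmax(values):
        # Pairwise tournament: each round compares candidates in pairs and
        # keeps the winner's value and index, giving ~log2(n) rounds.
        vals = values.flatten()
        idx = torch.arange(vals.numel())
        while vals.numel() > 1:
            if vals.numel() % 2 == 1:
                # odd length: carry the last candidate into the next round
                vals = torch.cat([vals, vals[-1:]])
                idx = torch.cat([idx, idx[-1:]])
            left, right = vals[0::2], vals[1::2]
            keep = (left >= right).long()  # 1 where the left candidate wins
            vals = keep * left + (1 - keep) * right
            idx = keep * idx[0::2] + (1 - keep) * idx[1::2]
        return idx.item()

    print(tournament_argmax(torch.tensor([3, 1, 4, 1, 5, 9, 2])))  # 5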

iamtrask commented 6 years ago

Much of .play() is slow for the same reason: the broken == operator.

stephenjfox commented 4 years ago

Is this issue still valid? It's been a year and a half.