hal3 / macarico

learning to search in pytorch
MIT License

Bootstrap Exploration #20

Closed amrsharaf closed 6 years ago

amrsharaf commented 6 years ago

An implementation of the bootstrap exploration policy. A reference implementation is available here

Outline

bootstrap.py: The BootstrapPolicy class maintains a bag of bag_size LinearPolicy instances

__call__ queries every policy in the bag, and each policy votes for its predicted action with probability 1/bag_size. Depending on the greedy_predict flag, it either returns the action with the highest probability or samples an action from the resulting probability distribution.
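A minimal sketch of the voting step described above, using NumPy only. The function names `vote_probabilities` and `choose_action` are hypothetical and are not taken from the PR; each bag member's vote is assumed to be an integer action index.

```python
import numpy as np


def vote_probabilities(voted_actions, n_actions):
    """Turn one vote per bag member into a distribution over actions.

    Each member contributes 1/bag_size to the action it voted for.
    """
    bag_size = len(voted_actions)
    counts = np.bincount(voted_actions, minlength=n_actions)
    return counts / bag_size


def choose_action(voted_actions, n_actions, greedy_predict=True, rng=None):
    """Pick the majority action (greedy) or sample from the vote distribution."""
    probs = vote_probabilities(voted_actions, n_actions)
    if greedy_predict:
        return int(np.argmax(probs))
    rng = np.random.default_rng() if rng is None else rng
    return int(rng.choice(n_actions, p=probs))
```

With votes `[0, 2, 2, 1]` over three actions, the distribution is `[0.25, 0.25, 0.5]`, so the greedy choice is action 2, while the sampling branch returns any of the three actions with those probabilities.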

predict_costs is similar to predict_costs in the LinearPolicy class, but it queries every policy in the bag and returns a list of all predicted costs. The list is of type BootstrapCost rather than a standard Python list, because BanditLOLS needs a single aggregated vector of predicted costs, for example when computing certainty scores. BootstrapCost handles this aggregation by averaging the cost vectors.

forward_partial_complete computes the loss for the BootstrapPolicy class and updates every policy in the bag Poisson(1) times. If the greedy_update flag is true, the first policy is updated exactly once.
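The update schedule can be sketched roughly as follows. This is an illustration under assumptions, not the code in the PR: `update_counts`, `bootstrap_update`, and the `update_fn` callback are hypothetical names, and the actual loss computation is left out.

```python
import numpy as np


def update_counts(bag_size, greedy_update=True, rng=None):
    """Draw how many times each bag member is updated on this example."""
    rng = np.random.default_rng() if rng is None else rng
    counts = rng.poisson(1.0, size=bag_size)  # one Poisson(1) draw per member
    if greedy_update:
        counts[0] = 1  # the first policy is updated exactly once
    return counts


def bootstrap_update(bag, update_fn, greedy_update=True, rng=None):
    """Apply update_fn to each bag member its sampled number of times."""
    counts = update_counts(len(bag), greedy_update, rng)
    for member, k in zip(bag, counts):
        for _ in range(int(k)):
            update_fn(member)
    return counts
```

Drawing the counts from Poisson(1) approximates resampling the training stream with replacement for each bag member, which is what makes the bag behave like an online bootstrap.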

BootstrapCost is a list of predicted costs that overrides the npvalue(), __getitem__(), and __neg__() methods. Each override aggregates the cost vectors by averaging and then calls npvalue(), __getitem__(), or __neg__() on the aggregated vector.
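To illustrate the aggregation, here is a small stand-in written against NumPy arrays. The class name `AveragedCosts` is hypothetical, and the cost vectors are assumed to be plain arrays of equal length rather than the framework tensors the PR works with.

```python
import numpy as np


class AveragedCosts(list):
    """A list of per-policy cost vectors that behaves like their average."""

    def _mean(self):
        # Iterating a list subclass uses list.__iter__, not our __getitem__.
        return np.mean(np.stack([np.asarray(c, dtype=float) for c in self]), axis=0)

    def npvalue(self):
        """Return the averaged cost vector as a NumPy array."""
        return self._mean()

    def __getitem__(self, idx):
        """Index into the averaged vector, not into the list of vectors."""
        return self._mean()[idx]

    def __neg__(self):
        """Negate the averaged vector."""
        return -self._mean()
```

For example, `AveragedCosts([np.array([1.0, 3.0]), np.array([3.0, 5.0])])` reports `npvalue()` as `[2.0, 4.0]`, so code that expects a single cost vector can consume it without knowing about the bag.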

actions_to_probs and bootstrap_probabilities are just helper functions for computing the bootstrap probabilities.

Updates to lols.py are pretty minimal:

  1. added bootstrap exploration to the do_exploration() function;
  2. refactored the code for computing certainty scores, moved the computation code to a single function that's then called three times for lols, banditlols and banditlols rewind;
  3. added an extra exploration flag.
hal3 commented 6 years ago

This looks great! Can you:

  1. change the random sampling code to just call np.random.poisson(1)?
  2. fix merge errors?
amrsharaf commented 6 years ago

Refactored the Poisson sampling code and resolved the merge conflicts (they were due to adding layers to the linear policy).

Please have one final look before merging.