abhisheknaik96 / differential-value-iteration

Experiments in creating the ultimate average-reward planning algorithm
Apache License 2.0

Async state order? #16

Open btanner opened 3 years ago

btanner commented 3 years ago

I noticed that the current algorithms seem to have behaviour like:

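    idx = self.idx
    self.v[idx] += update_here  # update_here: placeholder for the algorithm's increment for state idx
    self.idx = (self.idx + 1) % self.num_states  # move to the next state, wrapping around to 0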

So, the async updates sweep through each state in a fixed order. Is this the general behaviour we want?
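A minimal sketch of two of the orderings discussed in this thread: the fixed cyclic sweep in the snippet above, and the random-per-update option suggested later on. Discounted value iteration stands in for the repository's average-reward algorithms purely for brevity. The tabular layout (`P[a, s, s2]` transition probabilities, `R[a, s]` expected rewards) and the names `backup`, `cyclic_order`, `random_order` and `async_vi` are hypothetical, not the repo's API.

```python
import numpy as np

# Hypothetical tabular MDP layout (not the repository's):
#   P[a, s, s2] = probability of moving from s to s2 under action a
#   R[a, s]     = expected reward for taking action a in state s
# Discounted value iteration stands in for the average-reward algorithms.


def backup(P, R, gamma, v, s):
    """Bellman optimality backup for a single state s."""
    return np.max(R[:, s] + gamma * P[:, s, :] @ v)


def cyclic_order(num_states, rng):
    """Fixed sweep 0, 1, ..., num_states - 1, 0, ... (rng is unused)."""
    s = 0
    while True:
        yield s
        s = (s + 1) % num_states


def random_order(num_states, rng):
    """Pick a state uniformly at random for every update."""
    while True:
        yield int(rng.integers(num_states))


def async_vi(P, R, gamma, order, num_updates, seed=0):
    """In-place asynchronous VI: one state per update, chosen by `order`."""
    num_states = P.shape[1]
    rng = np.random.default_rng(seed)
    v = np.zeros(num_states)
    states = order(num_states, rng)
    for _ in range(num_updates):
        s = next(states)
        # In-place write: the next backup already sees this new value.
        v[s] = backup(P, R, gamma, v, s)
    return v
```

Both orderings update every state infinitely often in the limit (the cyclic one deterministically, the random one with probability 1), so for a discounted problem either should reach the same fixed point; only the path there differs.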

Some alternatives:

I'm asking for a few reasons. One is that, down the road, when running these async algorithms in JAX, we might want to perform N updates in parallel, with 1 <= N << num_states. I'm trying to work out the appropriate generalization of the current code to an intermediate approach like that.

@abhisheknaik96 Do you want to add @yiwan-rl to the project formally so I can assign issues to him?

yiwan-rl commented 3 years ago

You are right. The alternatives you mentioned are all valid approaches. In fact, we would like the algorithm to be totally asynchronous, i.e., states can be updated asynchronously in an arbitrary order, as long as every state is updated infinitely often in the limit. This is also the standard condition for other asynchronous iterative algorithms (e.g., asynchronous discounted value iteration).

The totally asynchronous condition is very general. It even allows delays: an update need not use the most recent state values. The conceptual computational model behind the totally asynchronous condition is the following: each processor maintains and updates one state's value using values received from all processors, and sends the updated value to all other processors. Each processor works without waiting for results from the others. There are unknown communication delays between the processors, so a value used in an update may not be the most recent one.
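The delayed model can be simulated on a single machine. A minimal sketch, assuming discounted value iteration as a stand-in for the average-reward algorithms and a hypothetical tabular layout (`P[a, s, s2]` transition probabilities, `R[a, s]` expected rewards): each step picks a random state, and its backup reads every other state's value from a copy that may be up to `max_delay` updates old, with the delay drawn independently per state. `totally_async_vi` and `max_delay` are illustrative names, and bounded random delays are an assumption of the sketch; the totally asynchronous condition itself admits more general delays.

```python
from collections import deque

import numpy as np


def totally_async_vi(P, R, gamma, num_updates, max_delay, seed=0):
    """Simulate asynchronous VI where reads of other states may be stale.

    Hypothetical layout: P[a, s, s2] is a transition probability and
    R[a, s] an expected reward. Discounted VI stands in for the
    average-reward algorithms.
    """
    num_states = P.shape[1]
    rng = np.random.default_rng(seed)
    # The most recent (max_delay + 1) value vectors; history[-1] is the newest.
    history = deque([np.zeros(num_states)], maxlen=max_delay + 1)
    for _ in range(num_updates):
        s = int(rng.integers(num_states))
        # Independent delay per state: 0 <= delay <= len(history) - 1 <= max_delay.
        delays = rng.integers(0, len(history), size=num_states)
        stale_v = np.array([history[-1 - int(d)][j] for j, d in enumerate(delays)])
        # The "processor" owning state s always knows its own latest value.
        stale_v[s] = history[-1][s]
        new_v = history[-1].copy()
        new_v[s] = np.max(R[:, s] + gamma * P[:, s, :] @ stale_v)
        history.append(new_v)
    return history[-1]
```

Setting `max_delay=0` recovers the in-place random-order update. With bounded delays and every state updated infinitely often, asynchronous discounted VI still converges to the same fixed point, which is the kind of guarantee described above.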

I am not sure how JAX computes things in parallel. Maybe Brian has some idea?

abhisheknaik96 commented 3 years ago

Out of all those valid options, I would prefer picking randomly for every update because:

For parallelization, perhaps we can use JAX/NumPy to pick a 'mini-batch' of N random indices at each step and update those states together?
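A minimal NumPy sketch of that mini-batch idea, once more with discounted value iteration standing in for the average-reward algorithms and a hypothetical layout (`P[a, s, s2]` transition probabilities, `R[a, s]` expected rewards). Each step draws `batch_size` distinct states uniformly at random and backs them all up from the same snapshot of `v`, as a vectorized or parallel update would. `minibatch_async_vi` is an illustrative name.

```python
import numpy as np


def minibatch_async_vi(P, R, gamma, batch_size, num_steps, seed=0):
    """Update `batch_size` random distinct states per step, all from one snapshot."""
    num_states = P.shape[1]
    rng = np.random.default_rng(seed)
    v = np.zeros(num_states)
    for _ in range(num_steps):
        # N distinct states, drawn uniformly at random for this step.
        idx = rng.choice(num_states, size=batch_size, replace=False)
        # P[:, idx, :] @ v has shape (num_actions, batch_size); every target
        # is computed from the old v before any entry is overwritten.
        targets = np.max(R[:, idx] + gamma * P[:, idx, :] @ v, axis=0)
        v[idx] = targets
    return v
```

`batch_size=1` is the random per-update scheme and `batch_size=num_states` is an ordinary synchronous sweep, so N interpolates between the two. A JAX version would likely swap in `jax.random.choice(key, num_states, shape=(N,), replace=False)` and `v.at[idx].set(targets)`; that port is not shown here.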

The intuition behind the totally asynchronous updates is really interesting; thanks for sharing, Yi!

P.S.: Yi, I've sent you an invite so that you'll be a 'collaborator' instead of just a 'contributor'.