rail-berkeley / rlkit

Collection of reinforcement learning algorithms
MIT License

Change sampling method from randint to choice in Replay and robustify policy networks in SAC #111

Closed: ksluck closed this 4 years ago

ksluck commented 4 years ago

This pull request proposes two changes:

Replacement of np.random.randint with np.random.choice in SimpleReplayBuffer

Using randint allows the same transition to appear more than once in a batch, which effectively gives the gradients/errors of those transitions a higher weight in the update. Sampling with choice and replace=False prevents this. To stay compatible with the current implementation, replace is set to False only when size >= batch_size; otherwise duplicates are still allowed (I assume the user has already checked that the buffer holds at least batch_size transitions). This simplified demo highlights the issue; in the randint sample, index 86 appears three times and 54 four times:

>>> import numpy as np
>>> size = 100
>>> batch_size = 64
>>> indices = np.random.randint(0, size, batch_size)
>>> indices
array([82, 49, 37, 39, 19, 86, 86, 12, 44, 68, 86, 30, 59, 82, 20, 66, 12, 53, 99, 95, 56, 69, 96, 89, 2, 7, 93, 38, 54, 48, 16, 71, 58, 7, 29, 34, 18, 54, 4, 62, 14, 95, 75, 59, 69, 98, 54, 57, 8, 8, 54, 14, 76, 66, 77, 37, 78, 30, 71, 43, 99, 70, 51, 20])

>>> indices_2 = np.random.choice(size, size=batch_size, replace=size<batch_size)
>>> indices_2
array([36,  6, 87, 79, 93,  0, 62, 98, 95, 71, 18, 73, 92, 37, 55, 80, 19,
       43, 49, 74, 56, 39,  1, 45, 29,  5, 32, 78, 28,  9,  2, 41, 26, 64,
       44, 38,  3, 33, 85,  8, 60, 51, 22, 16, 89, 63, 52, 83, 75, 81, 17,
       82, 15, 88, 53,  7,  4, 77, 40, 25, 30, 84, 13, 50])
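To put a number on the demo, one can count how many entries in a batch repeat an earlier entry. The sketch below compares the two sampling calls from the demo; `count_duplicates` is a hypothetical helper for illustration, not part of rlkit, and it relies only on `np.random.randint`, `np.random.choice` and `np.unique`.

```python
import numpy as np


def count_duplicates(indices):
    """Number of entries in `indices` that repeat an earlier entry."""
    return len(indices) - len(np.unique(indices))


size, batch_size = 100, 64

# Sampling with replacement: repeated transitions are possible (and likely here).
with_repl = np.random.randint(0, size, batch_size)
# Sampling without replacement, as proposed in this PR (replace=False since size >= batch_size).
without_repl = np.random.choice(size, size=batch_size, replace=size < batch_size)

print("randint duplicates:", count_duplicates(with_repl))
print("choice duplicates: ", count_duplicates(without_repl))  # always 0 here
```

Applied to the randint sample shown above, `count_duplicates` returns 19: only 45 of the 64 indices are distinct.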

This change may affect every algorithm that uses the replay buffer, so another round of tests might be a good idea here ;)

Small fix in the policy network of SAC

Just a small fix that should not affect any other method: currently, the log-probability is summed with .sum(dim=1), which assumes there is exactly one batch dimension. With two or more batch dimensions, dim=1 is a batch dimension rather than the action dimension, so we should use .sum(dim=-1) to sum over the last (data) dimension. This makes the class more flexible and should not affect current users: with a single batch dimension, dim=1 and dim=-1 refer to the same axis.
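The effect of the axis choice is easy to see in a NumPy sketch (NumPy's `axis` argument behaves like PyTorch's `dim` here). It assumes the usual tanh-squashed Gaussian policy used in SAC; `tanh_gaussian_log_prob` is a hypothetical helper, not rlkit's policy code, and the `1e-6` epsilon is an assumed numerical-stability constant.

```python
import numpy as np


def tanh_gaussian_log_prob(pre_tanh, mean, std, eps=1e-6):
    """Log-density of a = tanh(u) with u ~ N(mean, std**2), one value per sample.

    All inputs share the shape (*batch_dims, action_dim). Reducing over the
    last axis only means the result has shape batch_dims, however many
    leading batch dimensions there are.
    """
    # Per-dimension Gaussian log-density of the pre-squash sample u.
    normal_lp = (-0.5 * ((pre_tanh - mean) / std) ** 2
                 - np.log(std) - 0.5 * np.log(2 * np.pi))
    # Change-of-variables correction for the tanh squashing.
    correction = np.log(1 - np.tanh(pre_tanh) ** 2 + eps)
    # Sum over the action dimension (axis=-1), never over a batch axis.
    return (normal_lp - correction).sum(axis=-1)


rng = np.random.default_rng(0)
u = rng.normal(size=(10, 32, 6))  # e.g. (time, batch, action)
lp = tanh_gaussian_log_prob(u, np.zeros_like(u), np.ones_like(u))
print(lp.shape)  # (10, 32)
```

With inputs of shape (10, 32, 6), reducing with `axis=1` instead would return shape (10, 6), silently summing over the 32 samples rather than over the 6 action dimensions.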

vitchyr commented 4 years ago

Thanks for this! One reason I used np.random.randint is that it's quite a bit faster than choice, presumably because it doesn't try to prevent duplicates. Have you found that it actually makes a large difference? I imagine it wouldn't matter in many use cases, but it'd be nice to have this option. Instead of replacing the behavior, could you add a flag that lets people choose between the two options (defaulting to np.random.randint)? That way, training time won't change unexpectedly for existing users, but people have the option to sample without replacement.

Also, thanks for the policy network change!

ksluck commented 4 years ago

I would assume that in most cases where data generation is much faster than training (as in simulations or Atari games) and episodes have many steps, this is indeed not much of an issue: as the buffer grows, the probability of selecting the same transition more than once in a batch decreases quickly.
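This intuition is the birthday problem. Drawing b indices uniformly with replacement from N stored transitions (as randint does), the probability that all are distinct is the product of (1 - i/N) for i = 0, ..., b-1, roughly exp(-b(b-1)/(2N)), and the expected number of distinct indices is N(1 - (1 - 1/N)^b). The function names below are illustrative only.

```python
import math


def prob_any_duplicate(size, batch_size):
    """P(at least one repeated index) when drawing batch_size indices
    uniformly with replacement from range(size)."""
    if batch_size > size:
        return 1.0  # pigeonhole: a repeat is guaranteed
    # log P(all distinct) = sum_i log(1 - i/size); log1p keeps this accurate.
    log_all_distinct = sum(math.log1p(-i / size) for i in range(batch_size))
    return 1.0 - math.exp(log_all_distinct)


def expected_distinct(size, batch_size):
    """Expected number of distinct indices in such a batch."""
    return size * (1.0 - (1.0 - 1.0 / size) ** batch_size)


for n in (100, 1_000, 100_000, 10_000_000):
    print(n, round(prob_any_duplicate(n, 64), 4), round(expected_distinct(n, 64), 2))
```

For the demo's N = 100 and b = 64, about 47.4 distinct indices are expected, in line with the 45 in the randint sample above; for the timing comment's N = 10,000,000 and b = 512, the chance of any repeat is only about 1.3%.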

The impact will be larger when episodes have few steps, or when training outpaces data collection, e.g. when simulations are complex or data is collected in the real world, so that we have to (or can) train longer on the first few hundred steps collected. Well, the joy of doing robotics :smile:

Good point about the processing time. If I remember correctly, choice uses randint internally when replace=True, so we could add a parameter named replace to the class with a default value of True. That way we can always use choice and still get the behavior you describe. I am wondering how best to handle the case batch_size > size: silently using replace=True when the class flag is set to False seems a bit unsatisfying, so maybe we should issue a warning?
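A minimal sketch of the proposal, assuming a `replace` flag that defaults to True and a warning when sampling without replacement is impossible. `sample_indices` is a hypothetical helper, not rlkit's actual replay-buffer code, and the use of `RuntimeWarning` is an assumption.

```python
import warnings

import numpy as np


def sample_indices(size, batch_size, replace=True):
    """Sample `batch_size` indices into a buffer holding `size` transitions.

    replace=True keeps the fast with-replacement behaviour (the default);
    replace=False avoids duplicate transitions within a batch when possible.
    """
    if size <= 0:
        raise ValueError("cannot sample from an empty buffer")
    if not replace and batch_size > size:
        # Not enough distinct transitions: fall back, but tell the user.
        warnings.warn(
            f"batch_size={batch_size} exceeds buffer size={size}; "
            "sampling with replacement instead",
            RuntimeWarning,
        )
        replace = True
    return np.random.choice(size, size=batch_size, replace=replace)
```

A buffer class could store `replace` as a constructor argument and pass it through on every sampling call, so existing users keep the default behaviour.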

import time
import numpy as np

size = 10000000
batch_size = 512

a = time.time(); indices = np.random.randint(0, size, batch_size); b = time.time()
print(b - a)  # 0.00011873245239257812

a = time.time(); indices_2 = np.random.choice(size, size=batch_size, replace=False); b = time.time()
print(b - a)  # 0.4596829414367676

a = time.time(); indices_2 = np.random.choice(size, size=batch_size, replace=True); b = time.time()
print(b - a)  # 0.00021719932556152344

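A single time.time() measurement is noisy; `timeit` with a few repeats gives steadier numbers for the same three calls. (In NumPy's legacy `np.random.choice`, sampling without replacement draws a permutation of the entire population, which would explain why its cost grows with `size` rather than with `batch_size`.) `benchmark` is a hypothetical helper; the repeat counts are arbitrary choices.

```python
import timeit

import numpy as np


def benchmark(size, batch_size, repeat=3, number=1):
    """Best-of-`repeat` time per call (seconds) for each sampling strategy."""
    calls = {
        "randint": lambda: np.random.randint(0, size, batch_size),
        "choice(replace=True)": lambda: np.random.choice(size, size=batch_size, replace=True),
        "choice(replace=False)": lambda: np.random.choice(size, size=batch_size, replace=False),
    }
    # timeit.repeat returns one total per repeat; the minimum is the least noisy estimate.
    return {name: min(timeit.repeat(fn, repeat=repeat, number=number)) / number
            for name, fn in calls.items()}


if __name__ == "__main__":
    for name, seconds in benchmark(10_000_000, 512).items():
        print(f"{name:>22}: {seconds * 1e3:.3f} ms")
```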
ksluck commented 4 years ago

Added the proposed changes, plus a warning when the requested behaviour (sampling without replacement) cannot be met.