JuliaReinforcementLearning / ReinforcementLearningTrajectories.jl

A generalized experience replay buffer for reinforcement learning
MIT License

MultiplexTrace is critically broken #42

Closed HenriDeh closed 1 year ago

HenriDeh commented 1 year ago

I found a critical bug in MultiplexTraces. MWE:

t = MultiplexTraces{(:state, :next_state)}(Int[])
for i in 1:5
    push!(t, (; state=i))
end
#new episode
for i in 1:5
    push!(t, (; state=i))
end

These loops mimic what happens during the run loop when two episodes of length 4 are pushed to the trace (5 being the terminal state). Indexing t at index 5 should return (state = 1, next_state = 2), but instead it returns (state = 5, next_state = 1): it mixes the two episodes. You can also see that length(t) == 9 even though it should be 8.
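The off-by-one is easy to reproduce with a plain vector standing in for the multiplexed storage (a hypothetical stand-in for illustration, not the package's internals): if state[i] = v[i] and next_state[i] = v[i+1] over a single backing vector, appending a second episode silently pairs the terminal state of episode 1 with the initial state of episode 2.

```julia
# Minimal stand-in for a multiplexed trace: one backing vector, with
# state[i] = v[i] and next_state[i] = v[i+1].
multiplexed = Int[]
for i in 1:5            # episode 1 (5 = terminal state)
    push!(multiplexed, i)
end
for i in 1:5            # episode 2, appended to the same backing vector
    push!(multiplexed, i)
end

# With this layout the trace "length" is length(v) - 1 = 9, not 8.
trace_length = length(multiplexed) - 1

# Indexing at 5 straddles the episode boundary:
pair_at_5 = (state = multiplexed[5], next_state = multiplexed[6])
# → (state = 5, next_state = 1): the terminal state of episode 1 is paired
# with the initial state of episode 2.
```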

This is a huge problem: it makes the learning of basically every algorithm in RL.jl incorrect. @jeremiahpslewis

HenriDeh commented 1 year ago

In other words, MultiplexTraces lacks the ability to discern elements that are multiplexed (both a state and a next_state) from those that are only a next_state.

HenriDeh commented 1 year ago

I noticed this issue while working on a way to store information about episodes while still allowing the use of fixed-size replay buffers, which is currently impossible with the Episodes struct. I believe this issue can be solved by systematically using that approach. In a few words, it consists of tracking the start and end index of each episode's traces. I originally started implementing this to allow MultiStepBatchSampler to work on episodes that do not end at a terminal state (among other reasons). A related PR will be published soon.
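A minimal sketch of that bookkeeping (names and layout are hypothetical, not the PR's actual implementation): each episode records the range of indices in the flat trace at which transitions are valid, and an index is sampleable only if it falls inside one of those ranges.

```julia
# Track each episode's (startidx, endidx) range alongside the flat trace,
# so cross-episode transition pairs can be rejected.
episodes = Tuple{Int,Int}[]   # (startidx, endidx) per episode, in push order

# Record the two 5-step episodes from the MWE above, laid out back to back
# in a 10-slot trace (indices 1:5 and 6:10).
push!(episodes, (1, 4))   # episode 1: valid transition indices 1..4
push!(episodes, (6, 9))   # episode 2: valid transition indices 6..9

# An index is sampleable only if some episode's range contains it, so the
# terminal slots (5 and 10 here) are never paired across episodes.
is_valid(idx) = any(s <= idx <= e for (s, e) in episodes)
```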

jeremiahpslewis commented 1 year ago

I've seen this issue before, I think. It's effectively a cold start problem, right? I have wondered from time to time whether, from a Trace (but perhaps not from an RL algorithm) perspective, it's easier to think about state and prev_state, which is then clearly defined for every index, including the final one. At the moment, an element only 'drops into' state after something has been pushed to next_state, but next_state is not defined / should be null for the final iteration within an episode. The nice thing about prev_state is that it's null (unless manually specified?) for element 1, which doesn't vary with step length within an episode. I'm not saying we should start renaming things, just how I've been thinking about it ;)
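A quick sketch of that prev_state framing (the naming is hypothetical, not the package API): every index of the episode, including the terminal one, has a well-defined entry, and only index 1 lacks a predecessor, regardless of episode length.

```julia
# One 5-step episode; prev_state is null only at element 1.
states = [1, 2, 3, 4, 5]
prev_state(i) = i == 1 ? missing : states[i - 1]
# prev_state(1) is missing; prev_state(5) == 4, so the terminal index
# still indexes cleanly without reaching into the next episode.
```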

HenriDeh commented 1 year ago

How I'd like to rework this is to do (for the two episodes above):

idx     1  2  3  4  5  6  7  8  9  10
state   1  2  3  4  5  1  2  3  4  5
reward  1  2  3  4  u  1  2  3  4  u
ep#     1  1  1  1  1  2  2  2  2  2

where u is a placeholder for non-multiplexed traces to keep the indices aligned. There is also a Deque of episodes that tracks

ep1 => startidx = 1, endidx = 4
ep2 => startidx = 6, endidx = 9

Indexing, and consequently sampling, would be disallowed for indices 5 and 10. How to do this index rejection efficiently I don't know yet, but I'll figure it out.
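One way to avoid rejection sampling entirely (a sketch under the layout above, not the package's sampler): draw an episode in proportion to its length, then an offset within it, so the placeholder slots 5 and 10 are never drawn at all.

```julia
using Random

# Episode ranges from the comment above: (startidx, endidx) per episode.
episodes = [(1, 4), (6, 9)]
lengths = [e - s + 1 for (s, e) in episodes]
cum = cumsum(lengths)                       # cumulative count of valid indices

function sample_index(rng)
    r = rand(rng, 1:cum[end])               # position among all valid indices
    ep = searchsortedfirst(cum, r)          # which episode that position is in
    offset = r - (ep == 1 ? 0 : cum[ep - 1])
    return episodes[ep][1] + offset - 1     # map back to the flat trace index
end
```

This makes sampling O(log #episodes) per draw via the binary search, at the cost of maintaining the cumulative lengths when episodes are added or evicted from a fixed-size buffer.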