JuliaReinforcementLearning / ReinforcementLearningTrajectories.jl

A generalized experience replay buffer for reinforcement learning

Agent cache broke normalization #39

Closed: HenriDeh closed this issue 1 year ago

HenriDeh commented 1 year ago

And I don't know how to fix it without importing RLCore (and all the unnecessary dependencies that come with it).

Basically, the functions in normalization.jl

for f in (:push!, :pushfirst!, :append!, :prepend!)
    @eval function Base.$f(nt::NormalizedTraces, x::NamedTuple)
        # fit the matching normalizers on the incoming data, then forward to the wrapped traces
        for key in intersect(keys(nt.normalizers), keys(x))
            fit!(nt.normalizers[key], x[key])
        end
        $f(nt.traces, x)
    end
end

are broken, because what gets pushed is now an SRT or an SA rather than a NamedTuple. Unlike NamedTuple, these structs are not defined in Base (they come from RLCore).

I cannot leave x untyped either, because I cannot guess which fields among .reward, .state, .action, .terminal (and any arbitrary other, really) it has. Importing RLCore would allow dispatching on SA and SRT, but 1) we don't want those dependencies in this package, and 2) it would remain annoying as soon as we start using another type of cache (one with an action_log_prob field, for example).
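To make the failure mode concrete, here is a minimal, self-contained sketch. DummyNormalizer, DummyNormalizedTraces and DummySRT are hypothetical stand-ins, not the real definitions from this package or RLCore:

mutable struct DummyNormalizer     # stand-in for the package's normalizers
    n::Int
end
fit!(norm::DummyNormalizer, data) = (norm.n += length(data); norm)

struct DummyNormalizedTraces       # stand-in for NormalizedTraces
    normalizers::NamedTuple
    traces::Vector{Any}
end

# the existing methods only match NamedTuple payloads
function Base.push!(nt::DummyNormalizedTraces, x::NamedTuple)
    for key in intersect(keys(nt.normalizers), keys(x))
        fit!(nt.normalizers[key], x[key])
    end
    push!(nt.traces, x)
end

struct DummySRT                    # stand-in for RLCore's SRT agent cache
    state::Vector{Float64}
    reward::Float64
    terminal::Bool
end

nt = DummyNormalizedTraces((state = DummyNormalizer(0),), Any[])
push!(nt, (state = [1.0, 2.0], reward = 0.5))  # dispatches to the method above, normalizer is fit
push!(nt, DummySRT([1.0, 2.0], 0.5, false))    # MethodError: no push! method accepts a DummySRT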

HenriDeh commented 1 year ago

I accidentally pushed my commit to main instead of creating a branch, but the only change is the one in the commit diff screenshot (image not reproduced here). I think using getfield and fieldnames is not ideal, but at least it works and it is a really simple solution. Let me know if you have any comments.
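Since the screenshot is not reproduced here, the following is only a rough sketch of what a getfield/fieldnames-based version of the methods above could look like; the actual commit on main may differ in its details. The idea is to drop the NamedTuple restriction on x and introspect its fields instead:

for f in (:push!, :pushfirst!, :append!, :prepend!)
    @eval function Base.$f(nt::NormalizedTraces, x)
        # works for NamedTuple, SRT, SA, or any other cache struct,
        # as long as its field names match the normalizer keys
        for key in intersect(keys(nt.normalizers), fieldnames(typeof(x)))
            fit!(nt.normalizers[key], getfield(x, key))
        end
        $f(nt.traces, x)
    end
end

This keeps RLCore out of the package's dependencies, at the cost of relying on reflection (fieldnames/getfield) rather than dispatch.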