ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] Customizable SampleBatch functions #37767

Open ULudo opened 1 year ago

ULudo commented 1 year ago

Description

I'm not quite sure if this is a bug or a feature request, because as I understand it, the SampleBatch should not assume a specific data structure.

The ViewRequirement class can be used to specify data ranges for the individual parameters of a SampleBatch, e.g.:

self.view_requirements[SampleBatch.TERMINATEDS] = ViewRequirement(SampleBatch.TERMINATEDS, shift=f"-24:0")

This creates 2D arrays for the TERMINATEDS.

However, these array structures are not supported by the SampleBatch class. The is_terminated_or_truncated and is_single_trajectory functions expect 1D arrays for the TERMINATEDS and TRUNCATEDS properties. If the content of the properties is 2D, these functions cause errors or return wrong values.
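To illustrate the shape mismatch, here is a minimal NumPy-only sketch. It does not reproduce RLlib's implementation; it assumes the check reads the last entry of the flags array, and that in a windowed (2D) array the last column of the last row corresponds to the current timestep. The names `naive_last_flag` and `last_step_flag` are hypothetical helpers, not part of RLlib.

```python
import numpy as np


def naive_last_flag(flags) -> bool:
    """Assumed 1D-only check: treat the last entry as a single boolean."""
    return bool(np.asarray(flags)[-1])


def last_step_flag(flags) -> bool:
    """Shape-tolerant variant: works for 1D flags and windowed 2D flags."""
    arr = np.asarray(flags, dtype=bool)
    if arr.size == 0:
        return False
    if arr.ndim == 0:
        return bool(arr)
    # Collapse any trailing window dimensions, then take the newest entry of
    # the last row (assumption: the window ends at the current timestep).
    return bool(arr.reshape(arr.shape[0], -1)[-1, -1])


# 1D flags: one boolean per timestep.
flat = np.array([False, False, True])

# 2D flags: one window of 25 booleans per timestep (hypothetical shape).
windowed = np.zeros((3, 25), dtype=bool)
windowed[-1, -1] = True

# The 1D-only check fails on the windowed array because a row of 25
# booleans has no single truth value.
try:
    naive_last_flag(windowed)
    naive_failed = False
except ValueError:
    naive_failed = True

flat_result = last_step_flag(flat)
windowed_result = last_step_flag(windowed)
```

Whether the newest entry is the right one to inspect depends on how the shift window is defined, which is why a fix of this kind is application specific.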

In my case this can be easily fixed by changing the code of the functions as follows:

But these changes are application specific, and I had to modify the RLlib source to make them. It would be a great help if the functions in the SampleBatch class were customizable (or are they already, and I just didn't find it in the documentation?).

Use case

No response

ArturNiederfahrenhorst commented 1 year ago

Hi @ULudo,

This is a major pain point in RLlib's current implementation. As you say, SampleBatch makes some assumptions about the underlying data, where often it would be better if it were just a simple container, basically just a dict. Many parts of RLlib are, today, tightly integrated with the internal workings of SampleBatch, so it is extremely non-trivial to make changes there. Luckily, we've already decoupled our new training backend (the new RLModule/Learner stack) almost entirely from this. As RLlib progresses, we'll attempt to do the same with the sampling backend (RolloutWorker, EnvRunnerV2, AgentCollector, ...), so that we'll be able to disentangle this logic bit by bit.

Thanks for reporting your issue here. It will probably take a little while until we are able to pick this up. Not because it is not important, but because it depends on many more things that have to be done before we can get to it.