daniellawson9999 opened this issue 1 year ago.
First, thanks for using Minari! These questions are really helpful for us; it's difficult to refine a product without hearing from users.
OK, so for the first part: we are working on an optional StreamingDataset backend (https://docs.mosaicml.com/projects/streaming/en/stable/). We are open to design suggestions for parallel sampling with both the StreamingDataset and the h5py backend (I'm not sure what we need to do to get true parallelism in Python when memory is shared between physical cores; maybe this is easy). A parallel sampling implementation might be a strict improvement over our current implementation on any machine with more than one physical CPU core.
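For the h5py backend, one possible approach (a minimal sketch, not an existing Minari feature) is a process pool in which every worker opens its own dataset handle once, through the pool initializer, and then serves indexed reads. Processes rather than threads are used because CPython threads share the GIL and, as far as we know, h5py serializes its calls with its own lock, so threads are unlikely to overlap reads. The names `parallel_sample`, `open_fn` and `_read_episode` are hypothetical; passing `minari.load_dataset` as `open_fn` assumes the returned dataset supports square-bracket indexing by episode index.

```python
from concurrent.futures import ProcessPoolExecutor

# Per-process dataset handle, set once by the pool initializer
_dataset = None


def _init_worker(open_fn, name):
    # Each worker opens its own handle (e.g. its own h5py file), so no
    # open handle is shared or pickled across processes
    global _dataset
    _dataset = open_fn(name)


def _read_episode(index):
    # Square-bracket indexing on the per-process handle
    return _dataset[index]


def parallel_sample(open_fn, name, indices, max_workers=None, mp_context=None):
    """Read the items at `indices` using a pool of worker processes.

    open_fn(name) must return an object supporting obj[index]; for Minari
    this could be minari.load_dataset (an assumption, not a built-in feature).
    """
    with ProcessPoolExecutor(
        max_workers=max_workers,
        mp_context=mp_context,
        initializer=_init_worker,
        initargs=(open_fn, name),
    ) as pool:
        # map() returns results in the same order as `indices`
        return list(pool.map(_read_episode, indices))
```

For example, `parallel_sample(minari.load_dataset, "my-dataset", [0, 5, 7], max_workers=4)` (with a hypothetical dataset name) would return the three episodes in index order. On platforms that spawn worker processes, call it under an `if __name__ == "__main__":` guard.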
For the second one: we don't have any built-in features for sampling from multiple datasets at once. The closest thing that comes to mind is generating the list of indices to sample outside of Minari, then sampling from each dataset using `iterate_episodes` with that list as an argument (you can also use square brackets directly on the `MinariDataset` object to get an episode by index). That gives you fine-grained control to sample without replacement, sample the same indices from different datasets, etc. We also have sub-episode trajectory sampling code in development.
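A minimal sketch of that external index generation for several datasets, using only the standard library. `sample_indices` and `sample_batch` are hypothetical helpers; the sketch assumes unfiltered datasets whose episodes are indexed `0..total_episodes - 1`, and that `iterate_episodes` accepts the index list as `episode_indices`.

```python
import random


def sample_indices(totals, k, same_indices=False, seed=None):
    """Draw k distinct episode indices (without replacement) per dataset.

    totals: episode count of each dataset, e.g. [d.total_episodes for d in datasets].
    same_indices: reuse one draw for every dataset (it must fit the smallest one).
    """
    rng = random.Random(seed)
    if same_indices:
        shared = rng.sample(range(min(totals)), k)
        return [list(shared) for _ in totals]
    return [rng.sample(range(total), k) for total in totals]


def sample_batch(datasets, k, same_indices=False, seed=None):
    """Fetch k episodes from each dataset via externally generated indices."""
    # Assumes unfiltered datasets indexed 0..total_episodes-1 (hypothetical helper)
    totals = [d.total_episodes for d in datasets]
    per_dataset = sample_indices(totals, k, same_indices, seed)
    return [
        list(d.iterate_episodes(episode_indices=idx))
        for d, idx in zip(datasets, per_dataset)
    ]
```

For the multi-task case in this issue, something like `sample_batch([minari.load_dataset(name) for name in names], k=30)` would return one list of 30 episodes per dataset; each per-dataset call could also be dispatched to a separate worker.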
We are open to feature requests, so feel free to propose any features you think would support your use-case.
Thanks for the response! Regarding the StreamingDataset backend, is it currently in a public fork? I'm curious to take a look and see whether I could patch together something similar in the meantime, before it becomes an official feature.
Have you tried the joblib library yet? Just using PyTorch's `DataLoader` with n worker processes should also work. For example, with joblib:
```python
from joblib import Parallel, delayed
import minari


def _fetch_episode(dataset_name, i):
    # Runs inside a worker: load the dataset there so no open file handles
    # need to be pickled, then use square-bracket indexing to get episode i.
    # Note: this reloads the dataset on every call; batching several indices
    # per call (e.g. with iterate_episodes) would reduce that overhead.
    return minari.load_dataset(dataset_name)[i]


def _fetch_rewards(dataset_name, i):
    # Load the episode inside the worker and send back only its rewards
    return _fetch_episode(dataset_name, i).rewards


class MinariParallelWrapper:
    def __init__(self, dataset_name):
        self.dataset_name = dataset_name
        # Load the dataset using Minari (used by the non-parallel methods)
        self.dataset = minari.load_dataset(dataset_name)

    def get_episodes_parallel(self, n, n_jobs=-1):
        """Fetch the first 'n' episodes in parallel using joblib."""
        return Parallel(n_jobs=n_jobs)(
            delayed(_fetch_episode)(self.dataset_name, i) for i in range(n)
        )

    # Add other dataset methods here, with or without parallelism
    def total_episodes(self):
        return self.dataset.total_episodes

    # Wrapping other methods (optionally parallelized)
    # For example, getting rewards from episodes:
    def get_rewards_parallel(self, n, n_jobs=-1):
        """Fetch rewards from the first 'n' episodes in parallel."""
        return Parallel(n_jobs=n_jobs)(
            delayed(_fetch_rewards)(self.dataset_name, i) for i in range(n)
        )


# Example usage
if __name__ == "__main__":
    # Assume a dataset named "CartPole-v1-dataset" has already been created
    wrapper = MinariParallelWrapper("CartPole-v1-dataset")
    # Get 10 episodes in parallel
    episodes = wrapper.get_episodes_parallel(10)
    print(episodes)
    # Get rewards from the first 5 episodes in parallel
    rewards = wrapper.get_rewards_parallel(5)
    print(rewards)
```
Parallel episode sampling
I have a use case with a dataset of image-based observations, and sampling seems to be slower than with 1D observations. I checked how sampling works internally and noticed that Minari samples episodes serially instead of in parallel. I thought parallelizing this call may already have been considered, so I was curious about any recommendations for the best way to do this, and whether this is something that will be added in the future.
There is one more layer of complexity on top of this: instead of 1 dataset, I have, say, 10 datasets from different envs, each with image-based observations (think multi-task Atari). From these 10 Minari datasets I want, say, 30 episodes from each for every gradient update, also sampled in parallel. I will experiment with different parallelization techniques, but I'm curious whether others have intuition about this.
https://github.com/Farama-Foundation/Minari/blob/c0669fc3a8829dec4a7a1fbee198a6be4f668ea1/minari/dataset/minari_storage.py#L153-L180