pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.

Support Key/Value databases #711

Open msaroufim opened 2 years ago

msaroufim commented 2 years ago

🚀 The feature

The existing cache holder uses a Python dictionary:

@functional_datapipe("in_memory_cache")
class InMemoryCacheHolderMapDataPipe(MapDataPipe[T_co]):
    def __init__(self, source_dp: MapDataPipe[T_co]) -> None:
        self.source_dp: MapDataPipe[T_co] = source_dp
        self.cache: Dict[Any, T_co] = {}

    def __getitem__(self, index) -> T_co:
        if index not in self.cache:
            self.cache[index] = self.source_dp[index] 
        return self.cache[index]  # type: ignore[index]

But we could instead provide a generic interface for plugging in different cache providers like Redis or Memcached:

from abc import ABC, abstractmethod

class Cache(ABC):
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __getitem__(self, index) -> T_co:
        pass

class RedisCache(Cache):
    def __init__(self, url):
        self.client = setup_redis_client(url)

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

class MemCacheCache(Cache):
    def __init__(self, url):
        self.client = setup_memcache_client(url)

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
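
For illustration, here is a minimal sketch (my own, not part of the proposal) of how such a backend could plug into a MapDataPipe holder; the kv_cache name, the KVCacheHolderMapDataPipe class, and the dict-backed fallback are all hypothetical:

from typing import Any, Dict, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.map import MapDataPipe

T_co = TypeVar("T_co", covariant=True)

class DictCache:
    # Fallback backend with the same behavior as the current dict-based holder.
    def __init__(self) -> None:
        self._store: Dict[Any, Any] = {}

    def __contains__(self, index) -> bool:
        return index in self._store

    def __getitem__(self, index):
        return self._store[index]

    def __setitem__(self, index, value) -> None:
        self._store[index] = value

@functional_datapipe("kv_cache")  # hypothetical functional name
class KVCacheHolderMapDataPipe(MapDataPipe[T_co]):
    # Accepts any backend implementing __contains__/__getitem__/__setitem__,
    # e.g. a RedisCache or MemCacheCache as sketched above.
    def __init__(self, source_dp: MapDataPipe[T_co], cache=None) -> None:
        self.source_dp = source_dp
        self.cache = cache if cache is not None else DictCache()

    def __getitem__(self, index) -> T_co:
        # on a miss, pull from the source pipe and write through to the backend
        if index not in self.cache:
            self.cache[index] = self.source_dp[index]
        return self.cache[index]

    def __len__(self) -> int:
        return len(self.source_dp)

Note that write-through caching like this needs __setitem__ and __contains__ on the backend in addition to the __getitem__ shown in the ABC above.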

Motivation, pitch

Python dictionaries have a few limitations when used as a cache:

  1. They need to be copied per process
  2. Updating the cache in one process requires manually synchronizing it with the caches in other processes
  3. The entire dictionary needs to be loaded in memory even to look up a single element

So the goal of this work would be to reduce the memory overhead and cache misses of the cache in multiprocessing environments, while sacrificing some latency on cache hits, since a Python dictionary will always be faster to access than a remote KV store.

The 0.4 release was very much about leveraging remote object stores, so this work would follow that trend.

Alternatives

No response

Additional context

Our queues right now use Python lists (https://github.com/pytorch/data/blob/main/torchdata/dataloader2/communication/queue.py#L11) but could instead leverage message queues like Kafka or RabbitMQ, so imagine a similar solution there.

ejguan commented 2 years ago

It's a reasonable feature for MapDataPipe. Would it be possible to extend this cache interface to also support an in-memory cache for IterDataPipe, which follows a FIFO manner?

BTW, I think __contains__ is also required on the cache object, to check whether the requested item is already in the cache.
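
As a rough illustration of both points (this sketch is mine, not something proposed above), a FIFO-bounded in-memory backend exposing __contains__ could look like:

from collections import OrderedDict

class FIFOMemoryCache:
    # Hypothetical in-memory backend that evicts entries in insertion (FIFO) order.
    def __init__(self, max_elements=None):
        self._store = OrderedDict()
        self._max_elements = max_elements

    def __contains__(self, key):
        return key in self._store

    def __getitem__(self, key):
        return self._store[key]

    def __setitem__(self, key, value):
        self._store[key] = value
        if self._max_elements is not None and len(self._store) > self._max_elements:
            # drop the oldest inserted entry first
            self._store.popitem(last=False)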

msaroufim commented 2 years ago

Summarizing notes from meeting with Erjia

Spartee commented 2 years ago

After talking with @msaroufim I wanted to take a stab at this implementation. Below is the prototype code I put together solely for the purpose of discussion. This uses the IterDataPipe class, but the Map variant is probably even simpler.

Once we figure out a good strategy for handling some of the design items below I can put up a PR.


import pickle
from typing import Iterator, Optional, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

T_co = TypeVar("T_co", covariant=True)

def setup_redis_client(url, username=None, password=None):
    # TODO Implement with ACL support
    # TODO Implement with TLS support?
    # TODO Redis cluster client??
    try:
        import redis
        return redis.Redis(url)
    except ImportError:
        print("Redis needs to be installed in order to use the Redis cache for datapipes")
        raise
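
# Hypothetical sketch, not part of the prototype: the ACL/TLS TODOs above could
# presumably be covered with redis-py's username/password and ssl keyword arguments.
def setup_redis_client_secure(url, username=None, password=None, use_tls=False):
    import redis
    return redis.Redis(host=url, username=username, password=password, ssl=use_tls)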

@functional_datapipe("redis_cache")
class RedisCacheHolderIterDataPipe(IterDataPipe[T_co]):

    def __init__(self, source_dp: IterDataPipe[T_co], redis_url: str, cached_elements: Optional[int] = None) -> None:
        self.source_dp: IterDataPipe[T_co] = source_dp
        self._client = setup_redis_client(redis_url)
        self._key = "tpipe"
        self._start_idx = 0
        # use number of cached elements rather than cache size
        # avoids problem of using Redis DB size when Redis being used for
        # more than just a datapipe cache
        self.cached_elements = cached_elements

    def _iter_stored(self):
        # index always starts at 0 for redis list
        # _start_index solely for tracking number of stored elements
        for idx in range(0, self._cache_list_len()):
            # LRANGE? or Pipeline??
            yield self._deserialize(self._client.lindex(self._key, idx))

    def _deserialize(self, response):
        return pickle.loads(response)

    def _serialize(self, value):
        # TODO store datatype in Redis upon init? how to assert datatype
        # don't serialize for primitive datatypes? only collections?
        return pickle.dumps(value)

    def __iter__(self) -> Iterator[T_co]:
        if self._cache_list_len() > 1:
            for idx, data in enumerate(self.source_dp):
                if idx < self._start_idx:
                    yield data
                else:
                    break
            yield from self._iter_stored()
        else:
            for data in self.source_dp:
                self._client.rpush(self._key, self._serialize(data))

                # Cache reaches element limit
                if self.cached_elements is not None and self._cache_list_len() > self.cached_elements:
                    self._client.lpop(self._key) 
                    self._start_idx += 1
                yield data

    def __contains__(self, key):
        return self._client.exists(key)

    def _cache_list_len(self):
        return self._client.llen(self._key)

    def __len__(self) -> int:
        try:
            return len(self.source_dp)
        except TypeError:
            # if list has been created in the database
            if self._key in self:
                return self._start_idx + self._cache_list_len()
            else:
                raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.")

When running the following simple example

from torchdata.datapipes.iter import IterableWrapper
source_dp = IterableWrapper(range(10))
cache_dp = source_dp.redis_cache(redis_url="localhost")
print(list(cache_dp))

I get the expected answer

>>> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

I've also tested with the ag_news example augmented to use the redis cache

# Stack CSV Parser directly on top of web-stream
dp = HttpReader([URL[split]]).parse_csv()
cache_dp = dp.redis_cache(redis_url="localhost")
return cache_dp.map(_process_tuple)

A couple design points to mention

  1. I went for the cached_elements instead of size since it'll be tough to directly correlate list item size to cache size unless that is tracked separately. This provides an optional layer on top of the existing eviction strategies within Redis.
  2. start_idx is used in the case where cached_elements is triggered. This simply accounts for the fact that Redis lists always start at index 0.
  3. I wasn't sure how to handle the __len__ function, as this is my first pass at this library, so please chime in if that doesn't look right.

Design Points

I'll keep chipping away at this, but wanted to post early to gather feedback.

msaroufim commented 2 years ago

Hi @Spartee, thank you for your patience. Here are my thoughts.

Overall I think a PR doing what you describe should be pretty easy to merge, with the value being that a Redis cache is yet another data source that anyone could leverage.

Regarding some of your specific design questions

  1. cached_elements is a good idea, although I wonder: is it common for people to use Redis caches for more than one use case, or do folks typically have a cache per application?
  2. start_idx looks good
  3. __len__: for an iterable pipe I'd just raise an error (https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/load/huggingface.py#L79-L80), otherwise return the cache length
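
One way to read that suggestion, as a minimal sketch (my own; a drop-in for the prototype's __len__ above, reusing its _key attribute and _cache_list_len helper):

    def __len__(self) -> int:
        # only report a length once the cache list actually exists in Redis;
        # otherwise follow other IterDataPipes and raise
        if self._key in self:
            return self._start_idx + self._cache_list_len()
        raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.")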

On the future design points

ejguan commented 2 years ago

@Spartee Thank you for putting up a prototype! Here are my thoughts.

Aside from optimizing retrieval, we might be able to provide a multi-layer cache to reduce cache misses. For the Redis example, cached_elements is used to define the cache size. When we run out of cache, we might fall back to another layer of cache (maybe by default an on-disk file - this might need the format discussed with @msaroufim).

client init

Noob question about how the redis client/server works in Python. Let's say we have multiple processes running the same pipeline; would each redis client attach to the same server? If that's the case, it seems like we need to figure out a way to make sure the order of data is preserved, because normally the data should be in round-robin order (Process 0: [0, 2, 4, 6, ...], Process 1: [1, 3, 5, 7, ...]).

Serialization

We currently still rely on pickle, as PyTorch core still depends on it. But the cache itself could implement its own serialization strategy inside __getstate__ and __setstate__ to mitigate the security issue. Within each DataPipe, pickle will invoke the __getstate__ function, so the inner serialization logic can be called before pickling the cache data.
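
To make that concrete, here is a minimal sketch (mine; the JSON encoding is just a stand-in for whatever serialization strategy the cache chooses) of a cache backend that hooks __getstate__/__setstate__:

import json

class JsonStateCache:
    # Hypothetical in-memory backend that re-encodes its contents with json inside
    # __getstate__/__setstate__, so pickling the enclosing DataPipe never embeds
    # raw pickled payloads from the cache.
    def __init__(self):
        self._store = {}

    def __contains__(self, key):
        return key in self._store

    def __getitem__(self, key):
        return self._store[key]

    def __setitem__(self, key, value):
        self._store[key] = value

    def __getstate__(self):
        # the cache's own serialization runs before the outer pickle sees the data
        return {key: json.dumps(value) for key, value in self._store.items()}

    def __setstate__(self, state):
        self._store = {key: json.loads(value) for key, value in state.items()}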

Spartee commented 2 years ago

Happy to! Some responses

cached_elements is a good idea, although I wonder is it common for people to use Redis caches for more than one use case or do folks typically have a cache per application?

I think it's most common to have one cache per application, but I have seen some places where a number of microservices use the same cache for a number of purposes (pubsub/brokering/caching). Whatever we decide here should be well documented. I lean towards a more "hands off" approach here, since the size checks would need to occur regularly on the client side, which, under load, is not a negligible cost.

Serialization

Both are valid options presented by @ejguan and @msaroufim. The out-of-band performance of pickle (protocol 5) is quite good, and I know other PyData libraries use it. If we use dill, I would want to include that in a pip extra like torchdata[redis] for consistency.

I'm going to play around with serialization/compression and come back with some options.

multi-layer cache to reduce cache misses

This is a strategy that I think should be handled on the server side. Some variants of Redis, especially managed ones like Redis Enterprise, provide flash support for handling tiered caching. This is more performant than making multiple requests on the client side. There are also some OSS variants that provide tiered caching and are consistent with the OSS Redis API.

client init

If we solely focus on multiprocessing settings, I would think the best solution would be to init a client within each process at startup and keep it alive until the object is destroyed. How would the processes be initialized? Does the user pass them in? Can you point me to any examples that use a single datapipe with multiple processes?

ejguan commented 2 years ago

How would the processes be initialized? Does the user pass them in? Can you point me to any examples that use a single datapipe with multiple processes?

Here are some references. DataLoader2 would rely on MultiprocessingReadingService to spawn processes. https://github.com/pytorch/data/blob/86df1a09c0f649aca195a233508669de35b8623b/torchdata/dataloader2/reading_service.py#L147-L166

The datapipe should be automatically sharded based on the number of processes. Here is a minimal example for multiprocessing:

from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService

if __name__ == "__main__":
    input_dp = IterableWrapper(list(range(100)))
    dp = input_dp.shuffle().sharding_filter()

    rs = PrototypeMultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)

    for d in dl:
        print(d)