Open msaroufim opened 2 years ago
It's a reasonable feature for MapDataPipe. Would it be possible to extend this cache interface to also support an in-memory cache for IterDataPipe, which follows a FIFO eviction order? BTW, I think __contains__ is also required on the cache object, to check whether a requested item is already in the cache.
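To make this concrete, here is a minimal sketch of what such a FIFO-bounded in-memory cache with __contains__ could look like; the FIFOCache name and its method set are hypothetical, not an existing torchdata interface.

from collections import OrderedDict
from typing import Any, Hashable, Optional


class FIFOCache:
    # Hypothetical sketch of a FIFO-bounded in-memory cache, for discussion only.

    def __init__(self, max_elements: Optional[int] = None) -> None:
        self._store = OrderedDict()  # insertion order gives us FIFO eviction
        self.max_elements = max_elements

    def __setitem__(self, key: Hashable, value: Any) -> None:
        self._store[key] = value
        # Evict the oldest entry once the element limit is exceeded (FIFO order).
        if self.max_elements is not None and len(self._store) > self.max_elements:
            self._store.popitem(last=False)

    def __getitem__(self, key: Hashable) -> Any:
        return self._store[key]

    def __contains__(self, key: Hashable) -> bool:
        # Lets callers check whether a requested item is already cached.
        return key in self._store

Something shaped like this could back both the MapDataPipe and IterDataPipe cache holders.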
Summarizing notes from meeting with Erjia
After talking with @msaroufim I wanted to take a stab at this implementation. Below is prototype code I put together solely for the purpose of discussion. This uses the IterDataPipe class, but the map one is probably even simpler. Once we figure out a good strategy for handling some of the design items below I can open a PR.
import pickle
from typing import Iterator, Optional, TypeVar

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

T_co = TypeVar("T_co", covariant=True)


def setup_redis_client(url, username=None, password=None):
    # TODO Implement with ACL support
    # TODO Implement with TLS support?
    # TODO Redis cluster client??
    try:
        import redis

        return redis.Redis(url)
    except ImportError:
        print("Redis needs to be installed in order to use Redis cache for datapipes")
        raise


@functional_datapipe("redis_cache")
class RedisCacheHolderIterDataPipe(IterDataPipe[T_co]):
    def __init__(self, source_dp: IterDataPipe[T_co], redis_url: str, cached_elements: Optional[int] = None) -> None:
        self.source_dp: IterDataPipe[T_co] = source_dp
        self._client = setup_redis_client(redis_url)
        self._key = "tpipe"
        self._start_idx = 0
        # use number of cached elements rather than cache size
        # avoids the problem of using Redis DB size when Redis is being used for
        # more than just a datapipe cache
        self.cached_elements = cached_elements

    def _iter_stored(self):
        # index always starts at 0 for a redis list
        # _start_idx tracks how many elements have been evicted from the front
        for idx in range(0, self._cache_list_len()):
            # LRANGE? or Pipeline??
            yield self._deserialize(self._client.lindex(self._key, idx))

    def _deserialize(self, response):
        return pickle.loads(response)

    def _serialize(self, value):
        # TODO store datatype in Redis upon init? how to assert datatype
        # don't serialize for primitive datatypes? only collections?
        return pickle.dumps(value)

    def __iter__(self) -> Iterator[T_co]:
        if self._cache_list_len() > 1:
            # Cache already populated: re-read any evicted prefix from the source,
            # then yield the elements stored in Redis.
            for idx, data in enumerate(self.source_dp):
                if idx < self._start_idx:
                    yield data
                else:
                    break
            yield from self._iter_stored()
        else:
            for data in self.source_dp:
                self._client.rpush(self._key, self._serialize(data))
                # Cache reaches element limit: evict the oldest element
                if self.cached_elements is not None and self._cache_list_len() > self.cached_elements:
                    self._client.lpop(self._key)
                    self._start_idx += 1
                yield data

    def __contains__(self, key):
        return self._client.exists(key)

    def _cache_list_len(self):
        return self._client.llen(self._key)

    def __len__(self) -> int:
        try:
            return len(self.source_dp)
        except TypeError:
            # if the list has been created in the database
            if self._key in self:
                return self._start_idx + self._cache_list_len()
            else:
                raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.")
When running the following simple example
from torchdata.datapipes.iter import IterableWrapper
source_dp = IterableWrapper(range(10))
cache_dp = source_dp.redis_cache(redis_url="localhost")
print(list(cache_dp))
I get the expected answer
>>> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
I've also tested with the ag_news
example augmented to use the redis cache
# Stack CSV Parser directly on top of web-stream
dp = HttpReader([URL[split]]).parse_csv()
cache_dp = dp.redis_cache(redis_url="localhost")
return cache_dp.map(_process_tuple)
A couple of design points to mention:

- cached_elements is used instead of size, since it'll be tough to directly correlate list item size to cache size unless that is tracked separately. This provides an optional layer on top of the existing eviction strategies within Redis.
- start_idx is used in the case where cached_elements is triggered. This simply assists with the fact that Redis lists always start at 0.
- I wasn't sure how to handle the __len__ function, as this is my first pass at this library, so please chime in if that doesn't look right.
- I used pickle above, but that is not the ideal library to use here, primarily for security reasons. Other options? Protobuf? Dill?

I'll keep chipping away at this, but wanted to post early to gather feedback.
Hi @Spartee, thank you for your patience. Here are my thoughts.
Overall I think a PR doing what you describe should be pretty easy to merge, with the value being that a Redis cache is yet another data source that anyone could leverage.
Regarding some of your specific design questions:
__len__: for an iterable pipe I'd just raise an error (https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/load/huggingface.py#L79-L80), otherwise return the cache length.
On the future design points
@Spartee Thank you for putting up a prototype! Here are my thoughts.
Aside from optimizing retrieval, we might be able to provide a multiple-layer cache to reduce cache misses. For the redis example, cached_elements is used to define the cache size. When we run out of cache, we might fall back to another layer of cache (maybe by default an on-disk file - this might need the format discussed with @msaroufim).
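For illustration, here is a rough sketch of that fallback idea, assuming two hypothetical cache objects (a bounded fast layer and a larger slow layer, e.g. on-disk) that both support __contains__ and item access; none of these names are existing torchdata APIs.

from typing import Any, Hashable, Optional


class TwoLayerCache:
    # Hypothetical sketch: check a fast, size-limited layer first and fall back
    # to a slower, larger layer (e.g. an on-disk cache) to reduce cache misses.

    def __init__(self, fast_cache, slow_cache) -> None:
        self.fast = fast_cache
        self.slow = slow_cache

    def get(self, key: Hashable) -> Optional[Any]:
        if key in self.fast:
            return self.fast[key]
        if key in self.slow:
            value = self.slow[key]
            self.fast[key] = value  # promote on a slow-layer hit
            return value
        return None  # full miss; caller falls back to the source datapipe

    def put(self, key: Hashable, value: Any) -> None:
        self.fast[key] = value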
client init
Noob question about how the redis client/server works in Python: let's say we have multiple processes running the same pipeline, would each redis client attach to the same server? If that's the case, it seems like we need to figure out a way to make sure the order of data is preserved, because normally the data should be in round-robin order. (Process 0: [0, 2, 4, 6, ...], Process 1: [1, 3, 5, 7, ...])
Serialization
We currently still rely on pickle as pytorch core still depends on pickle. But the cache itself could implement its own serialization strategy inside __getstate__ and __setstate__ to eliminate the security issue. And, within each DataPipe, pickle will invoke the __getstate__ function, so the inner serialization logic can be called before pickling the cache data.
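A minimal sketch of that idea, using JSON purely as a stand-in for whatever safer encoding is eventually chosen; the SerializableCache class is hypothetical.

import json


class SerializableCache:
    # Hypothetical sketch: the cache controls its own (de)serialization so that
    # pickling the DataPipe never hands raw cached objects straight to pickle.

    def __init__(self) -> None:
        self._data = {}

    def __getstate__(self):
        # Invoked by pickle; swap the raw dict for a safer encoded form.
        state = self.__dict__.copy()
        state["_data"] = json.dumps(self._data)
        return state

    def __setstate__(self, state):
        # Decode with the cache's own strategy when unpickling.
        state["_data"] = json.loads(state["_data"])
        self.__dict__.update(state)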
Happy to! Some responses
cached_elements is a good idea, although I wonder: is it common for people to use Redis caches for more than one use case, or do folks typically have a cache per application?
I think it's most common to have one cache per application, but I have seen some places where a number of microservices use the same cache for a number of purposes (pubsub/brokering/caching). Whatever we decide here should be well documented. I lean towards a more "hands off" approach here, as the size checks would need to occur regularly on the client side which, under load, is not a negligible cost.
Serialization
Both are valid options presented by @ejguan and @msaroufim. The out-of-band performance of pickle (protocol 5) is quite good and I know other pydata libraries use it. If we use dill, I would want to include that in a pip extra like torchdata[redis] for consistency.
I'm going to play around with serialization/compression and come back with some options.
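Not part of the prototype above, but for reference, this is roughly how pickle protocol 5 hands large buffers out-of-band; NumPy is used here only as a stand-in for any buffer-backed payload.

import pickle

import numpy as np

payload = {"features": np.zeros((1024, 1024), dtype=np.float32)}

# With protocol 5, large buffers can be collected out-of-band instead of being
# copied into the pickle byte stream.
buffers = []
data = pickle.dumps(payload, protocol=5, buffer_callback=buffers.append)

restored = pickle.loads(data, buffers=buffers)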
multiple-layer cache to reduce cache misses
This is a strategy that I think should be handled on the server side. Some variants of Redis, especially managed ones like Redis Enterprise, provide flash support for handling tiered caching. This is more performant than making multiple requests on the client side. There are also some OSS variants that provide tiered caching and are consistent with the OSS Redis API.
client init
If we solely focus on multiprocessing settings, I would think the best solution would be to init a client within each process at startup and keep them alive until the object is destroyed. How would the processes be initialized? Does the user pass them in? Can you point me to any examples that use a single datapipe with multiple processes?
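One possible shape for the per-process client idea (a sketch, not the prototype's setup_redis_client): lazily create the connection the first time it is used inside a worker and recreate it if the process ID changes, so a connection opened in the parent is never shared across forked workers. Redis.from_url here assumes a full URL such as redis://localhost:6379.

import os


class LazyRedisClient:
    # Hypothetical sketch: one Redis connection per worker process.

    def __init__(self, redis_url: str) -> None:
        self._redis_url = redis_url
        self._client = None
        self._pid = None

    def get(self):
        # (Re)connect the first time we are used inside a given process.
        if self._client is None or self._pid != os.getpid():
            import redis  # assumes the redis-py package is installed

            self._client = redis.Redis.from_url(self._redis_url)
            self._pid = os.getpid()
        return self._client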
How would the processes be initialized? Does the user pass them in? Can you point me to any examples that use a single datapipe with multiple processes?
Here are some references. DataLoader2
would rely on MultiprocessingReadingService
to spawn processes.
https://github.com/pytorch/data/blob/86df1a09c0f649aca195a233508669de35b8623b/torchdata/dataloader2/reading_service.py#L147-L166
The datapipe should be automatically sharded based on the process number. Here is a minimal example for multiprocessing:
from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService

if __name__ == "__main__":
    input_dp = IterableWrapper(list(range(100)))
    dp = input_dp.shuffle().sharding_filter()
    rs = PrototypeMultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    for d in dl:
        print(d)
🚀 The feature
The existing cacheholder leverages a Python dictionary, but we could instead provide a generic interface to plug in different cache providers like Redis or memcached.
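One possible shape for such an interface (purely illustrative, not an agreed-upon API): a small protocol that a plain dict, a Redis wrapper, or a memcached wrapper could all satisfy.

from typing import Any, Hashable, Protocol


class CacheBackend(Protocol):
    # Hypothetical plug-in interface for cache providers; a plain dict already
    # satisfies the mapping portion of it.

    def __getitem__(self, key: Hashable) -> Any: ...

    def __setitem__(self, key: Hashable, value: Any) -> None: ...

    def __contains__(self, key: Hashable) -> bool: ...

    def __len__(self) -> int: ...

The current dictionary-based cacheholder would then be just one implementation of this interface.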
Motivation, pitch
Python dictionaries have a few limitations when used as a cache, so the goal of this work would be to reduce the memory overhead and cache misses of the cache in multiprocessing environments, while sacrificing some latency on cache hits, because a Python dictionary will be faster to access than a remote KV store.
The 0.4 release was very much about leveraging remote object stores, so this work would follow that trend.
Alternatives
No response
Additional context
Our queues right now use Python lists (https://github.com/pytorch/data/blob/main/torchdata/dataloader2/communication/queue.py#L11) but could instead leverage queues like Kafka or RabbitMQ, so imagine a similar solution.
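As a purely illustrative counterpart, the queue side could expose a similarly small protocol that the current list-backed queue or a Kafka/RabbitMQ adapter would implement; these names are hypothetical.

from typing import Any, Protocol


class QueueBackend(Protocol):
    # Hypothetical plug-in interface for the communication queues.

    def put(self, item: Any) -> None: ...

    def get(self) -> Any: ...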