Open smspillaz opened 3 months ago
I see, thanks for flagging. Yes, this is a known issue: we fetch credentials / metadata for each shard access. We're keeping an eye on this and may pursue a permanent solution. Your current workaround seems like a possible route too.
@snarayan21, do you have any temporary solution in mind? And what permanent solution are you thinking of? Is it wrapping all the download functions per cloud provider in a class?
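For illustration, one shape such a per-provider wrapper could take (a hedged sketch; `CloudDownloader` and `GCSDownloader` are hypothetical names, not existing streaming APIs):

```python
from abc import ABC, abstractmethod


class CloudDownloader(ABC):
    """Holds a long-lived client so credentials are not re-fetched per shard."""

    @abstractmethod
    def download(self, remote: str, local: str) -> None:
        ...


class GCSDownloader(CloudDownloader):

    def __init__(self) -> None:
        from google.cloud import storage

        # The client (and the credentials behind it) is created once per
        # process and reused, instead of once per downloaded shard.
        self._client = storage.Client()

    def download(self, remote: str, local: str) -> None:
        # remote is expected to look like 'gs://bucket/path/to/shard'.
        bucket_name, _, blob_name = remote.removeprefix('gs://').partition('/')
        self._client.bucket(bucket_name).blob(blob_name).download_to_filename(local)
```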
@rishabhm12 @smspillaz Wondering if either of you is interested in making a contribution?
Environment
To reproduce
Steps to reproduce the behavior:
1. Use `torchrun` to launch multiple processes (e.g., 4)
2. Use a `StreamingDataset` in a `DataLoader` with many worker processes (e.g., 4)

Because https://github.com/mosaicml/streaming/blob/main/streaming/base/storage/download.py#L235 queries the metadata service every time it is invoked in order to get credentials, doing this from multiple sub-processes at the same time can overload the service and exhaust the available connections, resulting in this warning:
Backoff inside of `google-auth` doesn't appear to add any jitter (it's just exponential), so if the worker subprocesses are running roughly in sync, this eventually fails even if we increase the timeout as specified here.

In principle we should not have to query the metadata service every time we need credentials. They are short-lived, but the `google.auth.compute_engine.GCECredentials` object provides an `expired` property (https://google-auth.readthedocs.io/en/master/reference/google.auth.compute_engine.html#module-google.auth.compute_engine), so it should be possible to cache the retrieved credentials for a given project ID and only refresh them when needed. In our case, we are monkey-patching the function to do the same thing:
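A minimal sketch of this kind of caching patch, assuming the patched function is `download_from_gcs` from the linked module (its exact signature, and any HMAC/service-account branches it may have, are omitted here):

```python
# Hedged sketch: cache credentials at module level and reuse them across shard
# downloads, refreshing only when they have expired. The patched function name
# and internals are assumptions based on streaming/base/storage/download.py.
import threading

import google.auth
import google.auth.transport.requests
from google.cloud import storage

_creds_lock = threading.Lock()
_cached_creds = None
_cached_project = None


def _cached_default():
    """Like google.auth.default(), but only hits the metadata service on the
    first call and when the cached credentials have expired."""
    global _cached_creds, _cached_project
    with _creds_lock:
        if _cached_creds is None:
            _cached_creds, _cached_project = google.auth.default()
        if not _cached_creds.valid:
            # Short-lived credentials: refresh instead of re-resolving default().
            _cached_creds.refresh(google.auth.transport.requests.Request())
    return _cached_creds, _cached_project


def patched_download_from_gcs(remote: str, local: str) -> None:
    """Hypothetical replacement that builds the client from cached credentials."""
    creds, project = _cached_default()
    client = storage.Client(project=project, credentials=creds)
    bucket_name, _, blob_name = remote.removeprefix('gs://').partition('/')
    client.bucket(bucket_name).blob(blob_name).download_to_filename(local)


# Applying the patch (module path per the link above):
import streaming.base.storage.download as streaming_download
streaming_download.download_from_gcs = patched_download_from_gcs
```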
Expected behavior

Shards can be fetched from GCS without making too many concurrent queries to the metadata service. Probably the fix here is to cache and refresh the credentials in the same way, though it's unclear to me where the caching should happen.
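One possible placement (an assumption, not a confirmed plan) is a module-level cache inside `streaming/base/storage/download.py`, guarded by a lock as in the sketch above. Since DataLoader worker processes don't share memory, the cache would be per-process, but that still cuts the load from one metadata query per shard to roughly one per worker process, plus occasional refreshes on expiry.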