Hello! I noticed an issue with loading a cache using the `load_or_create_cache()` function in `UnifiedDataset`. In the following code snippet, you can see that `keep_ids` is not defined when `cache_path` exists, as is the case when a cache has already been created and you are just trying to load it. After the `if` statement, `keep_ids` is expected to exist in order to remove the undesired entries from the data index; see the line `self.remove_elements(keep_ids=keep_ids)`. It seems like `keep_mask`, which is one of the outputs of `dill.load(f, encoding="latin1")`, should be renamed to `keep_ids` in order to fix this issue.
```python
def load_or_create_cache(
    self, cache_path: str, num_workers=0, filter_fn=None
) -> None:
    if isfile(cache_path):
        print(f"Loading cache from {cache_path} ...", end="")
        t = time.time()
        with open(cache_path, "rb") as f:
            self._cached_batch_elements, keep_mask = dill.load(f, encoding="latin1")
        print(f" done in {time.time() - t:.1f}s.")
    else:
        # Build cache
        cached_batch_elements = []
        keep_ids = []
        if num_workers <= 0:
            cache_data_iterator = self
        else:
            # Use DataLoader as a generic multiprocessing framework.
            # We set batchsize=1 and a custom collate function.
            # In effect this will just call self.__getitem__ in parallel.
            cache_data_iterator = DataLoader(
                self,
                batch_size=1,
                num_workers=num_workers,
                shuffle=False,
                collate_fn=lambda xlist: xlist[0],
            )
        for element in tqdm(
            cache_data_iterator,
            desc=f"Caching batch elements ({num_workers} CPUs): ",
            disable=False,
        ):
            if filter_fn is None or filter_fn(element):
                cached_batch_elements.append(element)
                keep_ids.append(element.data_index)
        # Just deletes the variable cache_data_iterator,
        # not self (in case it is set to that)!
        del cache_data_iterator
        print(f"Saving cache to {cache_path} ....", end="")
        t = time.time()
        with open(cache_path, "wb") as f:
            dill.dump((cached_batch_elements, keep_ids), f)
        print(f" done in {time.time() - t:.1f}s.")
        self._cached_batch_elements = cached_batch_elements
    # Remove unwanted elements
    self.remove_elements(keep_ids=keep_ids)
    # Verify
    if len(self._cached_batch_elements) != self._data_len:
        raise ValueError("Current data and cached data lengths do not match!")
```
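For what it's worth, here is a minimal, self-contained sketch of the round trip and the proposed rename. It uses `pickle` as a stand-in for `dill` and plain lists in place of real batch elements and data indices, so it only illustrates the unpacking, not the full `UnifiedDataset` behavior:

```python
import os
import pickle
import tempfile

# Stand-ins for the real cached batch elements and their data indices.
cached_batch_elements = ["elem_a", "elem_b"]
keep_ids = [0, 1]

# Save the cache the same way load_or_create_cache() does:
# a (elements, keep_ids) tuple.
cache_path = os.path.join(tempfile.mkdtemp(), "cache.dill")
with open(cache_path, "wb") as f:
    pickle.dump((cached_batch_elements, keep_ids), f)

# Proposed fix on the load path: bind the second item of the tuple to
# keep_ids (not keep_mask), so the later
# self.remove_elements(keep_ids=keep_ids) call has a defined variable.
with open(cache_path, "rb") as f:
    loaded_elements, keep_ids = pickle.load(f)

print(loaded_elements)
print(keep_ids)
```

With that rename, both branches of the `if` leave `keep_ids` defined before `self.remove_elements(keep_ids=keep_ids)` runs.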