observingClouds / xbitinfo

Python wrapper of BitInformation.jl to easily compress xarray datasets based on their information content
https://xbitinfo.readthedocs.io
MIT License

`MemoryError` when using `get_bitinformation` with the Python implementation on a high-resolution dataset #223

Open ayoubft opened 1 year ago

ayoubft commented 1 year ago

I am working with a high-resolution dataset with dimensions longitude: 24000, latitude: 12000, time: 1.

When I call `get_bitinformation` with the Python implementation, it raises this error: `MemoryError: Unable to allocate 8.58 GiB for an array with shape (287976000, 8, 4) and data type bool`

PS: When I revert to the Julia implementation, it works without this error.
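
As a quick sanity check of the numbers in the error message: a NumPy bool occupies one byte, so the reported shape alone accounts for the 8.58 GiB. The first dimension also matches the number of neighbouring element pairs along latitude, 24000 × (12000 − 1) × 1. The variable names below are illustrative only and are not taken from xbitinfo.

```python
# Dimensions reported above
longitude, latitude, time = 24000, 12000, 1

# Pairing each element with its neighbour along latitude drops one element on that axis
n_pairs = longitude * (latitude - 1) * time

# Shape from the error message is (n_pairs, 8, 4); a bool occupies 1 byte in NumPy
n_bytes = n_pairs * 8 * 4 * 1
gib = n_bytes / 2**30

print(n_pairs)        # 287976000
print(round(gib, 2))  # 8.58
```

So the failing allocation grows linearly with the number of element pairs, which suggests an intermediate array that is sized by the data rather than by the number of bits.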

Full output
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/numpy/ma/core.py:714, in getdata(a, subok)
    713 try:
--> 714     data = a._data
    715 except AttributeError:

AttributeError: 'Array' object has no attribute '_data'

During handling of the above exception, another exception occurred:

MemoryError                               Traceback (most recent call last)
Cell In[10], line 2
      1 # get information content per bit
----> 2 info_per_bit = xb.get_bitinformation(ds, dim="latitude", implementation="python")

File ~/xbitinfo/xbitinfo/xbitinfo.py:236, in get_bitinformation(ds, dim, axis, label, overwrite, implementation, **kwargs)
    234         info_per_bit[var] = info_per_bit_var
    235 elif implementation == "python":
--> 236     info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
    237     if info_per_bit_var is None:
    238         continue

File ~/xbitinfo/xbitinfo/xbitinfo.py:308, in _py_get_bitinformation(ds, var, axis, dim, kwargs)
    306 info_per_bit = {}
    307 logging.info("Calling python implementation now")
--> 308 info_per_bit["bitinfo"] = pb.bitinformation(X, axis=axis).compute()
    309 info_per_bit["dim"] = dim
    310 info_per_bit["axis"] = axis

File ~/xbitinfo/xbitinfo/_py_bitinfo.py:160, in bitinformation(a, axis)
    156 sa = tuple(slice(0, -1) if i == axis else slice(None) for i in range(len(a.shape)))
    157 sb = tuple(
    158     slice(1, None) if i == axis else slice(None) for i in range(len(a.shape))
    159 )
--> 160 return mutual_information(a[sa], a[sb])

File ~/xbitinfo/xbitinfo/_py_bitinfo.py:151, in mutual_information(a, b, base)
    149 pr = p.sum(axis=-1)[..., np.newaxis]
    150 ps = p.sum(axis=-2)[..., np.newaxis, :]
--> 151 mutual_info = (p * np.ma.log(p / (pr * ps))).sum(axis=(-1, -2)) / np.log(base)
    152 return mutual_info

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/numpy/ma/core.py:933, in _MaskedUnaryOperation.__call__(self, a, *args, **kwargs)
    928 def __call__(self, a, *args, **kwargs):
    929     """
    930     Execute the call behavior.
    931 
    932     """
--> 933     d = getdata(a)
    934     # Deal with domain
    935     if self.domain is not None:
    936         # Case 1.1. : Domained function
    937         # nans at masked positions cause RuntimeWarnings, even though
    938         # they are masked. To avoid this we suppress warnings.

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/numpy/ma/core.py:716, in getdata(a, subok)
    714     data = a._data
    715 except AttributeError:
--> 716     data = np.array(a, copy=False, subok=subok)
    717 if not subok:
    718     return data.view(ndarray)

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/array/core.py:1701, in Array.__array__(self, dtype, **kwargs)
   1700 def __array__(self, dtype=None, **kwargs):
-> 1701     x = self.compute()
   1702     if dtype and x.dtype != dtype:
   1703         x = x.astype(dtype)

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/base.py:310, in DaskMethodsMixin.compute(self, **kwargs)
    286 def compute(self, **kwargs):
    287     """Compute this dask collection
    288 
    289     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    308     dask.compute
    309     """
--> 310     (result,) = compute(self, traverse=False, **kwargs)
    311     return result

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/base.py:595, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    592     keys.append(x.__dask_keys__())
    593     postcomputes.append(x.__dask_postcompute__())
--> 595 results = schedule(dsk, keys, **kwargs)
    596 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
     86     elif isinstance(pool, multiprocessing.pool.Pool):
     87         pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
     90     pool.submit,
     91     pool._max_workers,
     92     dsk,
     93     keys,
     94     cache=cache,
     95     get_id=_thread_get_id,
     96     pack_exception=pack_exception,
     97     **kwargs,
     98 )
    100 # Cleanup pools associated to dead threads
    101 with pools_lock:

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    509         _execute_task(task, data)  # Re-execute locally
    510     else:
--> 511         raise_exception(exc, tb)
    512 res, worker_id = loads(res_info)
    513 state["cache"][key] = res

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/local.py:319, in reraise(exc, tb)
    317 if exc.__traceback__ is not tb:
    318     raise exc.with_traceback(tb)
--> 319 raise exc

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    222 try:
    223     task, data = loads(task_info)
--> 224     result = _execute_task(task, data)
    225     id = get_id()
    226     result = dumps((result, id))

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in _execute_task(arg, cache, dsk)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in (.0)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in _execute_task(arg, cache, dsk)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in (.0)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:115, in _execute_task(arg, cache, dsk)
     85 """Do the actual work of collecting data and executing a function
     86 
     87 Examples
   (...)
    112 'foo'
    113 """
    114 if isinstance(arg, list):
--> 115     return [_execute_task(a, cache) for a in arg]
    116 elif istask(arg):
    117     func, args = arg[0], arg[1:]

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:115, in (.0)
     85 """Do the actual work of collecting data and executing a function
     86 
     87 Examples
   (...)
    112 'foo'
    113 """
    114 if isinstance(arg, list):
--> 115     return [_execute_task(a, cache) for a in arg]
    116 elif istask(arg):
    117     func, args = arg[0], arg[1:]

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in _execute_task(arg, cache, dsk)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/optimization.py:992, in SubgraphCallable.__call__(self, *args)
    990 if not len(args) == len(self.inkeys):
    991     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 992 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:151, in get(dsk, out, cache)
    149 for key in toposort(dsk):
    150     task = dsk[key]
--> 151     result = _execute_task(task, cache)
    152     cache[key] = result
    153 result = _execute_task(out, cache)

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in _execute_task(arg, cache, dsk)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in (.0)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:115, in _execute_task(arg, cache, dsk)
     85 """Do the actual work of collecting data and executing a function
     86 
     87 Examples
   (...)
    112 'foo'
    113 """
    114 if isinstance(arg, list):
--> 115     return [_execute_task(a, cache) for a in arg]
    116 elif istask(arg):
    117     func, args = arg[0], arg[1:]

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:115, in (.0)
     85 """Do the actual work of collecting data and executing a function
     86 
     87 Examples
   (...)
    112 'foo'
    113 """
    114 if isinstance(arg, list):
--> 115     return [_execute_task(a, cache) for a in arg]
    116 elif istask(arg):
    117     func, args = arg[0], arg[1:]

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in _execute_task(arg, cache, dsk)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in (.0)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

File ~/miniconda3/envs/bitinfo/lib/python3.11/site-packages/dask/core.py:121, in _execute_task(arg, cache, dsk)
    117     func, args = arg[0], arg[1:]
    118     # Note: Don't assign the subtask results to a variable. numpy detects
    119     # temporaries by their reference count and can execute certain
    120     # operations in-place.
--> 121     return func(*(_execute_task(a, cache) for a in args))
    122 elif not ishashable(arg):
    123     return arg

MemoryError: Unable to allocate 8.58 GiB for an array with shape (287976000, 8, 4) and data type bool

observingClouds commented 1 year ago

Thanks @ayoubft! Could you provide a minimal code snippet and a link to the dataset you are using? This would be of great help. Thanks.

milankl commented 1 year ago

The error originates here:

mutual_info = (p * np.ma.log(p / (pr * ps))).sum(axis=(-1, -2)) / np.log(base)

in https://github.com/observingClouds/xbitinfo/blob/0d8852b5b7f3493dcd8a8bc88ec0dd97feb90dff/xbitinfo/_py_bitinfo.py#L143-L152

This confuses me: `a` and `b` should scale with the size of the data (they should be the non-allocating `1:` and `:-1` views of the actual data array), but `counts`, `p`, `pr` and `ps` shouldn't! They should be of size nbits x 4 (a 2x2 joint probability matrix for every bit position). @ayoubft, maybe you could check the size of these arrays? @observingClouds, could you clarify whether a lazy evaluation is triggered in this line?
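
To illustrate why these arrays should stay small, here is a minimal pure-Python sketch of the mutual information of a single bit position computed from its 2x2 table of pair counts. It mirrors the formula in the line quoted above but is not the xbitinfo implementation; the function name and the example count tables are made up for illustration.

```python
import math

def mutual_information_2x2(counts, base=2):
    """Mutual information of one bit position from a 2x2 table of pair counts.

    counts[i][j] is how often the bit was i in one element and j in its neighbour.
    """
    total = sum(sum(row) for row in counts)
    p = [[c / total for c in row] for row in counts]  # joint probabilities
    pr = [sum(row) for row in p]                      # marginal of the first element
    ps = [p[0][j] + p[1][j] for j in range(2)]        # marginal of the second element
    mi = 0.0
    for i in range(2):
        for j in range(2):
            if p[i][j] > 0:                           # treat 0 * log(0) as 0
                mi += p[i][j] * math.log(p[i][j] / (pr[i] * ps[j]))
    return mi / math.log(base)

# Hypothetical count tables: independent neighbours vs. identical neighbours
independent = mutual_information_2x2([[25, 25], [25, 25]])
identical = mutual_information_2x2([[50, 0], [0, 50]])
```

Whatever the size of the dataset, the input here is only four counts per bit position, so nothing in this step needs memory proportional to the data.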

milankl commented 1 year ago

Maybe related:

https://github.com/observingClouds/xbitinfo/blob/0d8852b5b7f3493dcd8a8bc88ec0dd97feb90dff/xbitinfo/_py_bitinfo.py#L128-L139

This seems to have an outer loop over the number of bits and then an inner loop over all elements in the data, so I suspect `(a >> s).astype("u1")` is allocating an entire copy of the array! In BitInformation.jl I therefore do this the other way around: loop over every element pair in the data, with an inner loop over the bits. This is non-allocating.
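
The loop order described for BitInformation.jl can be sketched in plain Python as follows. This is only an illustration of the idea (outer loop over element pairs, inner loop over bits), not the Julia code and not a performant implementation; `bitpair_counts` and its arguments are hypothetical names, and the values are assumed to be unsigned integers.

```python
def bitpair_counts(values, nbits):
    """Count joint bit values of neighbouring elements for every bit position.

    Returns counts[bit][i][j]: how often bit `bit` (most significant first)
    was i in one element and j in the next one. Only this small
    nbits x 2 x 2 table is allocated, independent of len(values).
    """
    counts = [[[0, 0], [0, 0]] for _ in range(nbits)]
    for k in range(len(values) - 1):      # outer loop: element pairs
        a, b = values[k], values[k + 1]
        for bit in range(nbits):          # inner loop: bit positions
            shift = nbits - 1 - bit
            i = (a >> shift) & 1
            j = (b >> shift) & 1
            counts[bit][i][j] += 1
    return counts

# Tiny example with 2-bit values
counts = bitpair_counts([0b00, 0b01, 0b11, 0b10], nbits=2)
print(counts)  # [[[1, 1], [0, 1]], [[0, 1], [1, 1]]]
```

Because the bits of each pair are extracted on the fly, no shifted or unpacked copy of the whole array is ever created.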

milankl commented 1 year ago

This issue sounds weirdly familiar, and indeed we discussed this already at the beginning of this year: https://github.com/observingClouds/xbitinfo/pull/156#issuecomment-1424618296

ayoubft commented 1 year ago

For the dataset, I am using the one shown in the attached image (but I will need to check whether it can be shared).

The code snippet is the following:

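path_to_data = 'data/netcdf/ecmwf_hs3g_20181101_msl.nc'
# Assumed loading step (xr = xarray, xb = xbitinfo); how ds was opened, e.g. its chunking, is not shown in the original snippet
ds = xr.open_dataset(path_to_data)
info_per_bit = xb.get_bitinformation(ds, dim="latitude", implementation="python")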

And it raises the error above.

observingClouds commented 1 year ago

Thanks @ayoubft! No worries about sharing the dataset. I'll find one myself.