lmcinnes / umap

Uniform Manifold Approximation and Projection

Fit on array of size >= 4096 returns _TracedPicklingError: Failed in nopython mode pipeline (step: nopython mode backend) #547

Open · milmin opened this issue 3 years ago

milmin commented 3 years ago

A simple UMAP fit on a random numpy array with 4096 or more rows fails with the traceback detailed below. With fewer than 4096 rows everything works fine. What's going wrong? A very similar issue: https://github.com/lmcinnes/umap/issues/477

Framework: Spark-based environment with umap==0.5.0, numba==0.52.0, pynndescent==0.5.1, scipy==1.4.1

Minimal example:

import umap
import numpy as np
reducer = umap.UMAP() 
np.random.seed(0)
reducer.fit(np.random.rand(4096,16))
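
For what it's worth, 4096 seems to be the cutoff where UMAP stops treating the dataset as "small" (exact nearest neighbors) and hands off to pynndescent's RP-tree search; the boundary can be checked directly (a sketch, assuming that cutoff):

import numpy as np
import umap

np.random.seed(0)
umap.UMAP().fit(np.random.rand(4095, 16))  # works: stays on the exact k-NN path
umap.UMAP().fit(np.random.rand(4096, 16))  # fails: triggers pynndescent's RP-tree build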

Running the 4096-row fit gives the following error:

---------------------------------------------------------------------------
PicklingError                             Traceback (most recent call last)
/databricks/python/lib/python3.7/site-packages/numba/core/serialize.py in save(self, obj)
    304         try:
--> 305             return super().save(obj)
    306         except _TracedPicklingError:

/databricks/python/lib/python3.7/pickle.py in save(self, obj, save_persistent_id)
    503         if f is not None:
--> 504             f(self, obj) # Call unbound method with explicit self
    505             return

/databricks/python/lib/python3.7/pickle.py in save_type(self, obj)
   1015             return self.save_reduce(type, (...,), obj=obj)
-> 1016         return self.save_global(obj)
   1017 

/databricks/python/lib/python3.7/pickle.py in save_global(self, obj, name)
    959                 "Can't pickle %r: it's not found as %s.%s" %
--> 960                 (obj, module_name, name)) from None
    961         else:

PicklingError: Can't pickle <class 'collections.FlatTree'>: it's not found as collections.FlatTree

During handling of the above exception, another exception occurred:

_TracedPicklingError                      Traceback (most recent call last)
<command-103842> in <module>
      3 reducer = umap.UMAP()
      4 np.random.seed(0)
----> 5 reducer.fit(np.random.rand(4096,16))

/databricks/python/lib/python3.7/site-packages/umap/umap_.py in fit(self, X, y)
   2387                 use_pynndescent=True,
   2388                 n_jobs=self.n_jobs,
-> 2389                 verbose=self.verbose,
   2390             )
   2391 

/databricks/python/lib/python3.7/site-packages/umap/umap_.py in nearest_neighbors(X, n_neighbors, metric, metric_kwds, angular, random_state, low_memory, use_pynndescent, n_jobs, verbose)
    337             low_memory=low_memory,
    338             n_jobs=n_jobs,
--> 339             verbose=verbose,
    340         )
    341         knn_indices, knn_dists = knn_search_index.neighbor_graph

/databricks/python/lib/python3.7/site-packages/pynndescent/pynndescent_.py in __init__(self, data, metric, metric_kwds, n_neighbors, n_trees, leaf_size, pruning_degree_multiplier, diversify_prob, n_search_trees, tree_init, init_graph, random_state, low_memory, max_candidates, n_iters, delta, n_jobs, compressed, verbose)
    789                 current_random_state,
    790                 self.n_jobs,
--> 791                 self._angular_trees,
    792             )
    793             leaf_array = rptree_leaf_array(self._rp_forest)

/databricks/python/lib/python3.7/site-packages/pynndescent/rp_trees.py in make_forest(data, n_neighbors, n_trees, leaf_size, rng_state, random_state, n_jobs, angular)
    999             result = joblib.Parallel(n_jobs=n_jobs, prefer="threads")(
   1000                 joblib.delayed(make_dense_tree)(data, rng_states[i], leaf_size, angular)
-> 1001                 for i in range(n_trees)
   1002             )
   1003     except (RuntimeError, RecursionError, SystemError):

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
   1015 
   1016             with self._backend.retrieval_context():
-> 1017                 self.retrieve()
   1018             # Make sure that we get a last message telling us we are done
   1019             elapsed_time = time.time() - self._start_time

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in retrieve(self)
    907             try:
    908                 if getattr(self._backend, 'supports_timeout', False):
--> 909                     self._output.extend(job.get(timeout=self.timeout))
    910                 else:
    911                     self._output.extend(job.get())

/databricks/python/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658 
    659     def _set(self, i, obj):

/databricks/python/lib/python3.7/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
    119         job, i, func, args, kwds = task
    120         try:
--> 121             result = (True, func(*args, **kwds))
    122         except Exception as e:
    123             if wrap_exception and func is not _helper_reraises_exception:

/databricks/python/lib/python3.7/site-packages/joblib/_parallel_backends.py in __call__(self, *args, **kwargs)
    606     def __call__(self, *args, **kwargs):
    607         try:
--> 608             return self.func(*args, **kwargs)
    609         except KeyboardInterrupt:
    610             # We capture the KeyboardInterrupt and reraise it as

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in __call__(self)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257 
    258     def __len__(self):

/databricks/python/lib/python3.7/site-packages/joblib/parallel.py in <listcomp>(.0)
    254         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    255             return [func(*args, **kwargs)
--> 256                     for func, args, kwargs in self.items]
    257 
    258     def __len__(self):

/databricks/python/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    431                     e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
    432             # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 433             raise e
    434 
    435     def inspect_llvm(self, signature=None):

/databricks/python/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
    364                 argtypes.append(self.typeof_pyval(a))
    365         try:
--> 366             return self.compile(tuple(argtypes))
    367         except errors.ForceLiteralArg as e:
    368             # Received request for compiler re-entry with the list of arguments

/databricks/python/lib/python3.7/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

/databricks/python/lib/python3.7/site-packages/numba/core/dispatcher.py in compile(self, sig)
    855             self._cache_misses[sig] += 1
    856             try:
--> 857                 cres = self._compiler.compile(args, return_type)
    858             except errors.ForceLiteralArg as e:
    859                 def folded(args, kws):

/databricks/python/lib/python3.7/site-packages/numba/core/dispatcher.py in compile(self, args, return_type)
     75 
     76     def compile(self, args, return_type):
---> 77         status, retval = self._compile_cached(args, return_type)
     78         if status:
     79             return retval

/databricks/python/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_cached(self, args, return_type)
     89 
     90         try:
---> 91             retval = self._compile_core(args, return_type)
     92         except errors.TypingError as e:
     93             self._failed_cache[key] = e

/databricks/python/lib/python3.7/site-packages/numba/core/dispatcher.py in _compile_core(self, args, return_type)
    107                                       args=args, return_type=return_type,
    108                                       flags=flags, locals=self.locals,
--> 109                                       pipeline_class=self.pipeline_class)
    110         # Check typing error if object mode is used
    111         if cres.typing_error is not None and not flags.enable_pyobject:

/databricks/python/lib/python3.7/site-packages/numba/core/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    600     pipeline = pipeline_class(typingctx, targetctx, library,
    601                               args, return_type, flags, locals)
--> 602     return pipeline.compile_extra(func)
    603 
    604 

/databricks/python/lib/python3.7/site-packages/numba/core/compiler.py in compile_extra(self, func)
    350         self.state.lifted = ()
    351         self.state.lifted_from = None
--> 352         return self._compile_bytecode()
    353 
    354     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

/databricks/python/lib/python3.7/site-packages/numba/core/compiler.py in _compile_bytecode(self)
    412         """
    413         assert self.state.func_ir is None
--> 414         return self._compile_core()
    415 
    416     def _compile_ir(self):

/databricks/python/lib/python3.7/site-packages/numba/core/compiler.py in _compile_core(self)
    392                 self.state.status.fail_reason = e
    393                 if is_final_pipeline:
--> 394                     raise e
    395         else:
    396             raise CompilerError("All available pipelines exhausted")

/databricks/python/lib/python3.7/site-packages/numba/core/compiler.py in _compile_core(self)
    383             res = None
    384             try:
--> 385                 pm.run(self.state)
    386                 if self.state.cr is not None:
    387                     break

/databricks/python/lib/python3.7/site-packages/numba/core/compiler_machinery.py in run(self, state)
    337                     (self.pipeline_name, pass_desc)
    338                 patched_exception = self._patch_error(msg, e)
--> 339                 raise patched_exception
    340 
    341     def dependency_analysis(self):

/databricks/python/lib/python3.7/site-packages/numba/core/compiler_machinery.py in run(self, state)
    328                 pass_inst = _pass_registry.get(pss).pass_inst
    329                 if isinstance(pass_inst, CompilerPass):
--> 330                     self._runPass(idx, pass_inst, state)
    331                 else:
    332                     raise BaseException("Legacy pass in use")

/databricks/python/lib/python3.7/site-packages/numba/core/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

/databricks/python/lib/python3.7/site-packages/numba/core/compiler_machinery.py in _runPass(self, index, pss, internal_state)
    287             mutated |= check(pss.run_initialization, internal_state)
    288         with SimpleTimer() as pass_time:
--> 289             mutated |= check(pss.run_pass, internal_state)
    290         with SimpleTimer() as finalize_time:
    291             mutated |= check(pss.run_finalizer, internal_state)

/databricks/python/lib/python3.7/site-packages/numba/core/compiler_machinery.py in check(func, compiler_state)
    260 
    261         def check(func, compiler_state):
--> 262             mangled = func(compiler_state)
    263             if mangled not in (True, False):
    264                 msg = ("CompilerPass implementations should return True/False. "

/databricks/python/lib/python3.7/site-packages/numba/core/typed_passes.py in run_pass(self, state)
    447 
    448         # TODO: Pull this out into the pipeline
--> 449         NativeLowering().run_pass(state)
    450         lowered = state['cr']
    451         signature = typing.signature(state.return_type, *state.args)

/databricks/python/lib/python3.7/site-packages/numba/core/typed_passes.py in run_pass(self, state)
    373                 lower.lower()
    374                 if not flags.no_cpython_wrapper:
--> 375                     lower.create_cpython_wrapper(flags.release_gil)
    376 
    377                 if not flags.no_cfunc_wrapper:

/databricks/python/lib/python3.7/site-packages/numba/core/lowering.py in create_cpython_wrapper(self, release_gil)
    242         self.context.create_cpython_wrapper(self.library, self.fndesc,
    243                                             self.env, self.call_helper,
--> 244                                             release_gil=release_gil)
    245 
    246     def create_cfunc_wrapper(self):

/databricks/python/lib/python3.7/site-packages/numba/core/cpu.py in create_cpython_wrapper(self, library, fndesc, env, call_helper, release_gil)
    160                                 fndesc, env, call_helper=call_helper,
    161                                 release_gil=release_gil)
--> 162         builder.build()
    163         library.add_ir_module(wrapper_module)
    164 

/databricks/python/lib/python3.7/site-packages/numba/core/callwrapper.py in build(self)
    120 
    121         api = self.context.get_python_api(builder)
--> 122         self.build_wrapper(api, builder, closure, args, kws)
    123 
    124         return wrapper, api

/databricks/python/lib/python3.7/site-packages/numba/core/callwrapper.py in build_wrapper(self, api, builder, closure, args, kws)
    185 
    186             retty = self._simplified_return_type()
--> 187             obj = api.from_native_return(retty, retval, env_manager)
    188             builder.ret(obj)
    189 

/databricks/python/lib/python3.7/site-packages/numba/core/pythonapi.py in from_native_return(self, typ, val, env_manager)
   1387                                                     "prevented the return of " \
   1388                                                     "optional value"
-> 1389         out = self.from_native_value(typ, val, env_manager)
   1390         return out
   1391 

/databricks/python/lib/python3.7/site-packages/numba/core/pythonapi.py in from_native_value(self, typ, val, env_manager)
   1401 
   1402         c = _BoxContext(self.context, self.builder, self, env_manager)
-> 1403         return impl(typ, val, c)
   1404 
   1405     def reflect_native_value(self, typ, val, env_manager=None):

/databricks/python/lib/python3.7/site-packages/numba/core/boxing.py in box_namedtuple(typ, val, c)
    502     Convert native array or structure *val* to a namedtuple object.
    503     """
--> 504     cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class))
    505     tuple_obj = box_tuple(typ, val, c)
    506     obj = c.pyapi.call(cls_obj, tuple_obj)

/databricks/python/lib/python3.7/site-packages/numba/core/pythonapi.py in serialize_object(self, obj)
   1362             gv = self.module.__serialized[obj]
   1363         except KeyError:
-> 1364             struct = self.serialize_uncached(obj)
   1365             name = ".const.picklebuf.%s" % (id(obj) if config.DIFF_IR == 0 else "DIFF_IR")
   1366             gv = self.context.insert_unique_const(self.module, name, struct)

/databricks/python/lib/python3.7/site-packages/numba/core/pythonapi.py in serialize_uncached(self, obj)
   1333         """
   1334         # First make the array constant
-> 1335         data = serialize.dumps(obj)
   1336         assert len(data) < 2**31
   1337         name = ".const.pickledata.%s" % (id(obj) if config.DIFF_IR == 0 else "DIFF_IR")

/databricks/python/lib/python3.7/site-packages/numba/core/serialize.py in dumps(obj)
    166     with io.BytesIO() as buf:
    167         p = pickler(buf)
--> 168         p.dump(obj)
    169         pickled = buf.getvalue()
    170 

/databricks/python/lib/python3.7/pickle.py in dump(self, obj)
    435         if self.proto >= 4:
    436             self.framer.start_framing()
--> 437         self.save(obj)
    438         self.write(STOP)
    439         self.framer.end_framing()

/databricks/python/lib/python3.7/site-packages/numba/core/serialize.py in save(self, obj)
    312             m = (f"Failed to pickle because of\n  {type(e).__name__}: {e}"
    313                  f"\ntracing... \n{perline(self.__trace)}")
--> 314             raise _TracedPicklingError(m)
    315         finally:
    316             self.__trace.pop()

_TracedPicklingError: Failed in nopython mode pipeline (step: nopython mode backend)
Failed to pickle because of
  PicklingError: Can't pickle <class 'collections.FlatTree'>: it's not found as collections.FlatTree
tracing... 
 [0]: <class 'type'>: 94412180539632

Rasmitha23 commented 3 years ago

Hi, is anybody able to resolve this? I am getting the same error in Jupyter notebooks. I tried both the cosine and the euclidean distance metric.

Following the closed issue linked above, I have also defined cosine as a custom numba-jitted function (along the lines of the sketch below) and passed it as the metric. I am still getting the same error.
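
A hedged sketch of what that attempt presumably looked like (the exact function is an assumption, not Rasmitha23's actual code; UMAP does accept a numba-jitted callable as the metric):

import numba
import numpy as np
import umap

# Hand-rolled cosine distance compiled with numba, as suggested in #477.
@numba.njit()
def cosine(a, b):
    dot = 0.0
    norm_a = 0.0
    norm_b = 0.0
    for i in range(a.shape[0]):
        dot += a[i] * b[i]
        norm_a += a[i] * a[i]
        norm_b += b[i] * b[i]
    return 1.0 - dot / np.sqrt(norm_a * norm_b)

reducer = umap.UMAP(metric=cosine)  # reportedly still fails at >= 4096 rows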

Please help

oscarorti commented 3 years ago

Hello,

I am still facing this issue. Is there a known workaround, or do we know when the bugfix will be released?

Thank you!

Osherz5 commented 3 years ago

> Hello,
>
> I am still facing this issue. Is there a known workaround, or do we know when the bugfix will be released?
>
> Thank you!

It's a problem in pynndescent that was fixed in the latest release; try updating that package to 0.5.4.

oscarorti commented 3 years ago

> Hello, I am still facing this issue. Is there a known workaround, or do we know when the bugfix will be released? Thank you!
>
> It's a problem in pynndescent that was fixed in the latest release; try updating that package to 0.5.4.

Thank you, but... that's the version I already have installed in my environment:

print(pynndescent.__version__)  # 0.5.4
print(umap.__version__)         # 0.5.1

asif7adil commented 2 years ago

Hi, I am facing this issue in a PySpark environment using a Jupyter notebook. Any dataframe with fewer than 4000 rows works fine, but as soon as the row count increases this error pops up. Has anybody been able to resolve this? I am also running t-SNE on my single-cell data in the same PySpark environment, and it works fine. I have tried downgrading umap-learn and pynndescent, mixing different versions, and using the latest versions of both; nothing helps.

oscarorti commented 2 years ago

Hello 🙋‍♂️

I finally found a hack. It seems that some odd internal overriding leaves the collections.FlatTree class registered under the wrong module, so pickle cannot find it there. The failure can be reproduced outside of umap entirely; see the sketch below.
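
A standalone illustration of the failure mode (the field names are illustrative, not pynndescent's actual ones):

import collections
import pickle

# Mimic a namedtuple class whose __module__ wrongly points at collections.
FlatTree = collections.namedtuple("FlatTree", ["hyperplanes", "offsets", "children", "indices"])
FlatTree.__module__ = "collections"

try:
    pickle.dumps(FlatTree)
except pickle.PicklingError as e:
    print(e)  # Can't pickle <class 'collections.FlatTree'>: it's not found as collections.FlatTree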

# This is a hack to be able to use UMAP.fit_transform with more than 4095 samples.
# See links below:
# https://github.com/lmcinnes/umap/issues/477
# https://github.com/lmcinnes/umap/issues/547
import collections
import pynndescent

# Create a dummy namedtuple with an explicit module kwarg (part of the original fix).
collections.namedtuple("n", [], module=__name__)
# Re-register FlatTree under its actual defining module so pickle can locate it.
pynndescent.rp_trees.FlatTree.__module__ = "pynndescent.rp_trees"

I hope this helps!
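
For reference, a minimal end-to-end sketch of the workaround applied to the failing example from the top of the thread (assuming the affected versions):

import collections
import numpy as np
import pynndescent
import umap

# Apply the fix right after the imports, before constructing UMAP.
collections.namedtuple("n", [], module=__name__)
pynndescent.rp_trees.FlatTree.__module__ = "pynndescent.rp_trees"

np.random.seed(0)
embedding = umap.UMAP().fit_transform(np.random.rand(4096, 16))
print(embedding.shape)  # (4096, 2) once the pickling error is gone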

oscarorti commented 2 years ago

Just copy-paste it right below your module's imports.

asif7adil commented 2 years ago

> Just copy-paste it right below your module's imports.

It is working like a charm now. 😁 Thank you so much @oscarorti! I tried it on the 5k PBMC dataset for now; I will keep the community updated on the 68k one.