chanzuckerberg / cellxgene-census

CZ CELLxGENE Discover Census
https://chanzuckerberg.github.io/cellxgene-census/
MIT License

pytorch unit test hangs on Python 3.9 #521

Closed bkmartinjr closed 1 year ago

bkmartinjr commented 1 year ago

Describe the bug

The test_experiment_dataloader__multiprocess_pickling unit test hangs when run on Linux with Python 3.9. I have let it sit for 12+ hours with no change.
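As an aside (not part of the original repro steps): for a hang like this, the stdlib faulthandler module can dump every thread's stack after a deadline, which avoids waiting on a manual KeyboardInterrupt. A minimal sketch, with a sleep standing in for the hanging section:

```python
import faulthandler
import time

# Schedule a traceback dump of all threads if the process is still
# running after the timeout. faulthandler writes directly to the file
# descriptor, so it works even if the main thread is blocked in C code.
log = open("hang_dump.txt", "w")
faulthandler.dump_traceback_later(timeout=0.5, file=log)

time.sleep(1.0)  # stand-in for the hanging call (e.g. next(iter(dl)))

# Cancel the pending dump once the suspect section is past.
faulthandler.cancel_dump_traceback_later()
log.close()

with open("hang_dump.txt") as f:
    print(f.read())  # shows where each thread was at the deadline
```

pytest also exposes this mechanism through its built-in faulthandler plugin (the faulthandler_timeout ini option), which dumps tracebacks for any test exceeding the given number of seconds.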

$ pytest ./api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py::test_experiment_dataloader__multiprocess_pickling --full-trace -v --experimental
=================================================================== test session starts ====================================================================
platform linux -- Python 3.9.16, pytest-7.3.1, pluggy-1.0.0 -- /home/bruce/cellxgene-census/venv/bin/python
cachedir: .pytest_cache
rootdir: /home/bruce/cellxgene-census/api/python/cellxgene_census
configfile: pyproject.toml
plugins: requests-mock-1.10.0
collected 1 item                                                                                                                                           

api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py::test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]

After a keyboard interrupt (with --full-trace enabled):

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! KeyboardInterrupt !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

config = <_pytest.config.Config object at 0x7fc2f52b1b50>, doit = <function _main at 0x7fc2f554e940>

    def wrap_session(
        config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
    ) -> Union[int, ExitCode]:
        """Skeleton command line program."""
        session = Session.from_config(config)
        session.exitstatus = ExitCode.OK
        initstate = 0
        try:
            try:
                config._do_configure()
                initstate = 1
                config.hook.pytest_sessionstart(session=session)
                initstate = 2
>               session.exitstatus = doit(config, session) or 0

venv/lib/python3.9/site-packages/_pytest/main.py:269: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

config = <_pytest.config.Config object at 0x7fc2f52b1b50>, session = <Session cellxgene_census exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>

    def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
        """Default command line protocol for initialization, session,
        running tests and reporting."""
        config.hook.pytest_collection(session=session)
>       config.hook.pytest_runtestloop(session=session)

venv/lib/python3.9/site-packages/_pytest/main.py:323: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session cellxgene_census exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>}, argname = 'session', firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()

        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break

            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False

>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x7fc2f52c5370>, hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/home/bruce/cellxgene-census/venv/lib/python3.9/sit...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7fc2f5160a60>>]
kwargs = {'session': <Session cellxgene_census exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>}, firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

session = <Session cellxgene_census exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>

    def pytest_runtestloop(session: "Session") -> bool:
        if session.testsfailed and not session.config.option.continue_on_collection_errors:
            raise session.Interrupted(
                "%d error%s during collection"
                % (session.testsfailed, "s" if session.testsfailed != 1 else "")
            )

        if session.config.option.collectonly:
            return True

        for i, item in enumerate(session.items):
            nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
>           item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)

venv/lib/python3.9/site-packages/_pytest/main.py:348: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>, 'nextitem': None}
argname = 'nextitem', firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()

        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break

            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False

>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x7fc2f52c5370>, hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/home/bruce/cellxgene-census/venv/lib/python3.9...=<module '_pytest.warnings' from '/home/bruce/cellxgene-census/venv/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>, 'nextitem': None}
firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>, nextitem = None

    def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
        ihook = item.ihook
        ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
>       runtestprotocol(item, nextitem=nextitem)

venv/lib/python3.9/site-packages/_pytest/runner.py:114: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>, log = True, nextitem = None

    def runtestprotocol(
        item: Item, log: bool = True, nextitem: Optional[Item] = None
    ) -> List[TestReport]:
        hasrequest = hasattr(item, "_request")
        if hasrequest and not item._request:  # type: ignore[attr-defined]
            # This only happens if the item is re-run, as is done by
            # pytest-rerunfailures.
            item._initrequest()  # type: ignore[attr-defined]
        rep = call_and_report(item, "setup", log)
        reports = [rep]
        if rep.passed:
            if item.config.getoption("setupshow", False):
                show_test_item(item)
            if not item.config.getoption("setuponly", False):
>               reports.append(call_and_report(item, "call", log))

venv/lib/python3.9/site-packages/_pytest/runner.py:133: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>, when = 'call', log = True, kwds = {}

    def call_and_report(
        item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
    ) -> TestReport:
>       call = call_runtest_hook(item, when, **kwds)

venv/lib/python3.9/site-packages/_pytest/runner.py:222: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)

    def call_runtest_hook(
        item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
    ) -> "CallInfo[None]":
        if when == "setup":
            ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
        elif when == "call":
            ihook = item.ihook.pytest_runtest_call
        elif when == "teardown":
            ihook = item.ihook.pytest_runtest_teardown
        else:
            assert False, f"Unhandled runtest hook case: {when}"
        reraise: Tuple[Type[BaseException], ...] = (Exit,)
        if not item.config.getoption("usepdb", False):
            reraise += (KeyboardInterrupt,)
>       return CallInfo.from_call(
            lambda: ihook(item=item, **kwds), when=when, reraise=reraise
        )

venv/lib/python3.9/site-packages/_pytest/runner.py:261: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cls = <class '_pytest.runner.CallInfo'>, func = <function call_runtest_hook.<locals>.<lambda> at 0x7fc0c1027e50>, when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)

    @classmethod
    def from_call(
        cls,
        func: "Callable[[], TResult]",
        when: "Literal['collect', 'setup', 'call', 'teardown']",
        reraise: Optional[
            Union[Type[BaseException], Tuple[Type[BaseException], ...]]
        ] = None,
    ) -> "CallInfo[TResult]":
        """Call func, wrapping the result in a CallInfo.

        :param func:
            The function to call. Called without arguments.
        :param when:
            The phase in which the function is called.
        :param reraise:
            Exception or exceptions that shall propagate if raised by the
            function, instead of being wrapped in the CallInfo.
        """
        excinfo = None
        start = timing.time()
        precise_start = timing.perf_counter()
        try:
>           result: Optional[TResult] = func()

venv/lib/python3.9/site-packages/_pytest/runner.py:341: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>       lambda: ihook(item=item, **kwds), when=when, reraise=reraise
    )

venv/lib/python3.9/site-packages/_pytest/runner.py:262: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>}, argname = 'item'
firstresult = False

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()

        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break

            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False

>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x7fc2f52c5370>, hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/home/bruce/cellxgene-census/venv/lib/python3.9...est.threadexception' from '/home/bruce/cellxgene-census/venv/lib/python3.9/site-packages/_pytest/threadexception.py'>>]
kwargs = {'item': <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>}, firstresult = False

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

item = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>

    def pytest_runtest_call(item: Item) -> None:
        _update_current_test_var(item, "call")
        try:
            del sys.last_type
            del sys.last_value
            del sys.last_traceback
        except AttributeError:
            pass
        try:
>           item.runtest()

venv/lib/python3.9/site-packages/_pytest/runner.py:169: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>

    def runtest(self) -> None:
        """Execute the underlying test function."""
>       self.ihook.pytest_pyfunc_call(pyfuncitem=self)

venv/lib/python3.9/site-packages/_pytest/python.py:1799: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>}, argname = 'pyfuncitem'
firstresult = True

    def __call__(self, *args, **kwargs):
        if args:
            raise TypeError("hook calling supports only keyword arguments")
        assert not self.is_historic()

        # This is written to avoid expensive operations when not needed.
        if self.spec:
            for argname in self.spec.argnames:
                if argname not in kwargs:
                    notincall = tuple(set(self.spec.argnames) - kwargs.keys())
                    warnings.warn(
                        "Argument(s) {} which are declared in the hookspec "
                        "can not be found in this hook call".format(notincall),
                        stacklevel=2,
                    )
                    break

            firstresult = self.spec.opts.get("firstresult")
        else:
            firstresult = False

>       return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_hooks.py:265: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <_pytest.config.PytestPluginManager object at 0x7fc2f52c5370>, hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/home/bruce/cellxgene-census/venv/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>}, firstresult = True

    def _hookexec(self, hook_name, methods, kwargs, firstresult):
        # called from all hookcaller instances.
        # enable_tracing will set its own wrapping function at self._inner_hookexec
>       return self._inner_hookexec(hook_name, methods, kwargs, firstresult)

venv/lib/python3.9/site-packages/pluggy/_manager.py:80: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

pyfuncitem = <Function test_experiment_dataloader__multiprocess_pickling[3-3-X_layer_names0-pytorch_x_value_gen]>

    @hookimpl(trylast=True)
    def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
        testfunction = pyfuncitem.obj
        if is_async_function(testfunction):
            async_warn_and_skip(pyfuncitem.nodeid)
        funcargs = pyfuncitem.funcargs
        testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
>       result = testfunction(**testargs)

venv/lib/python3.9/site-packages/_pytest/python.py:194: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

soma_experiment = <Experiment '/tmp/pytest-of-bruce/pytest-5/test_experiment_dataloader__mu0/exp' (open for 'r') (2 items)
    'ms': 'fi...xp/ms' (unopened)
    'obs': 'file:///tmp/pytest-of-bruce/pytest-5/test_experiment_dataloader__mu0/exp/obs' (unopened)>

    @pytest.mark.experimental
    # noinspection PyTestParametrized,DuplicatedCode
    @pytest.mark.parametrize("n_obs,n_vars,X_layer_names,X_value_gen", [(3, 3, ("raw",), pytorch_x_value_gen)])
    def test_experiment_dataloader__multiprocess_pickling(soma_experiment: Experiment) -> None:
        """
        If the DataPipe is accessed prior to multiprocessing (num_workers > 0), its internal _query will be
        initialized. But since it cannot be pickled, we must ensure it is ignored during pickling in multiprocessing mode.
        This test verifies the correct pickling behavior is in place.
        """

        dp = ExperimentDataPipe(
            soma_experiment,
            measurement_name="RNA",
            X_name="raw",
            obs_column_names=["label"],
        )
        dl = experiment_dataloader(dp, num_workers=2)
        dp.obs_encoders()  # trigger query building
>       row = next(iter(dl))  # trigger multiprocessing

api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py:368: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fc0c076b8b0>

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                # TODO(https://github.com/pytorch/pytorch/issues/76750)
                self._reset()  # type: ignore[call-arg]
>           data = self._next_data()

venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py:633: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fc0c076b8b0>

    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
>           idx, data = self._get_data()

venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1328: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fc0c076b8b0>

    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
>               success, data = self._try_get_data()

venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1294: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7fc0c076b8b0>, timeout = 5.0

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
>           data = self._data_queue.get(timeout=timeout)

venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py:1132: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <multiprocessing.queues.Queue object at 0x7fc0c076b970>, block = True, timeout = 4.999992875033058

    def get(self, block=True, timeout=None):
        if self._closed:
            raise ValueError(f"Queue {self!r} is closed")
        if block and timeout is None:
            with self._rlock:
                res = self._recv_bytes()
            self._sem.release()
        else:
            if block:
                deadline = time.monotonic() + timeout
            if not self._rlock.acquire(block, timeout):
                raise Empty
            try:
                if block:
                    timeout = deadline - time.monotonic()
>                   if not self._poll(timeout):

/usr/lib/python3.9/multiprocessing/queues.py:113: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <multiprocessing.connection.Connection object at 0x7fc0c076b940>, timeout = 4.999992875033058

    def poll(self, timeout=0.0):
        """Whether there is any input available to be read"""
        self._check_closed()
        self._check_readable()
>       return self._poll(timeout)

/usr/lib/python3.9/multiprocessing/connection.py:257: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <multiprocessing.connection.Connection object at 0x7fc0c076b940>, timeout = 4.999992875033058

    def _poll(self, timeout):
>       r = wait([self], timeout)

/usr/lib/python3.9/multiprocessing/connection.py:424: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

object_list = [<multiprocessing.connection.Connection object at 0x7fc0c076b940>], timeout = 4.999992875033058

    def wait(object_list, timeout=None):
        '''
        Wait till an object in object_list is ready/readable.

        Returns list of those objects in object_list which are ready/readable.
        '''
        with _WaitSelector() as selector:
            for obj in object_list:
                selector.register(obj, selectors.EVENT_READ)

            if timeout is not None:
                deadline = time.monotonic() + timeout

            while True:
>               ready = selector.select(timeout)

/usr/lib/python3.9/multiprocessing/connection.py:931: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <selectors.PollSelector object at 0x7fc0c060f400>, timeout = 5000

    def select(self, timeout=None):
        # This is shared between poll() and epoll().
        # epoll() has a different signature and handling of timeout parameter.
        if timeout is None:
            timeout = None
        elif timeout <= 0:
            timeout = 0
        else:
            # poll() has a resolution of 1 millisecond, round away from
            # zero to wait *at least* timeout seconds.
            timeout = math.ceil(timeout * 1e3)
        ready = []
        try:
>           fd_event_list = self._selector.poll(timeout)
E           KeyboardInterrupt

/usr/lib/python3.9/selectors.py:416: KeyboardInterrupt
============================================================ no tests ran in 258.25s (0:04:18) =============================================================
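The docstring of the failing test (visible in the traceback above) describes the behavior under test: the DataPipe's internal _query cannot be pickled, so it must be excluded when the DataLoader's worker processes pickle the pipe. A minimal, hypothetical sketch of that pattern — the class and attribute names are illustrative, not the actual ExperimentDataPipe implementation:

```python
import pickle


class DataPipeSketch:
    """Illustrative stand-in for a DataPipe holding an unpicklable handle."""

    def __init__(self):
        # Stand-in for the lazily initialized, unpicklable _query
        # (e.g. an open SOMA query/connection).
        self._query = None

    def build_query(self):
        # A generator is unpicklable, like an open query handle.
        self._query = (i for i in range(3))

    def __getstate__(self):
        # Drop the unpicklable attribute before pickling; the worker
        # copy can rebuild its query lazily after unpickling.
        state = self.__dict__.copy()
        state["_query"] = None
        return state


dp = DataPipeSketch()
dp.build_query()  # _query now holds an unpicklable generator
# Without __getstate__, pickle.dumps(dp) would raise a TypeError here.
clone = pickle.loads(pickle.dumps(dp))
print(clone._query)  # None: the unpicklable state was dropped
```

If pickling of the initialized state is not handled this way, a multiprocessing DataLoader worker can fail or stall while deserializing its copy of the pipe, which is consistent with the hang occurring at the first next(iter(dl)) in the test.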

Environment

From tiledbsoma.show_package_versions():

tiledbsoma.__version__        1.2.5
TileDB-Py tiledb.version()    (0, 21, 3)
TileDB core version           2.15.2
libtiledbsoma version()       libtiledb=2.15.2
python version                3.9.16.final.0
OS version                    Linux 5.15.0-1037-aws
$ pip freeze
aiobotocore==2.4.2
aiohttp==3.8.4
aioitertools==0.11.0
aiosignal==1.3.1
anndata==0.8.0
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
autopep8==2.0.2
backcall==0.2.0
black==23.3.0
bleach==6.0.0
botocore==1.27.59
build==0.10.0
cellxgene-census==1.2.0
certifi==2022.12.7
cffi==1.15.1
cfgv==3.3.1
charset-normalizer==3.1.0
click==8.1.3
cmake==3.26.3
contourpy==1.0.7
coverage==7.2.7
cryptography==41.0.1
cycler==0.11.0
Cython==0.29.33
decorator==5.1.1
distlib==0.3.6
docutils==0.20.1
exceptiongroup==1.1.1
executing==1.2.0
filelock==3.10.7
fonttools==4.39.3
frozenlist==1.3.3
fsspec==2023.3.0
gitdb==4.0.10
GitPython==3.1.31
h5py==3.8.0
identify==2.5.22
idna==3.4
importlib-metadata==6.6.0
importlib-resources==5.12.0
iniconfig==2.0.0
ipython==8.12.0
jaraco.classes==3.2.3
jedi==0.18.2
jeepney==0.8.0
Jinja2==3.1.2
jmespath==1.0.1
joblib==1.2.0
keyring==23.13.1
kiwisolver==1.4.4
lit==16.0.5
llvmlite==0.39.1
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mdurl==0.1.2
more-itertools==9.1.0
mpmath==1.3.0
multidict==6.0.4
mypy-extensions==1.0.0
natsort==8.3.1
nbqa==1.7.0
networkx==3.0
nodeenv==1.7.0
numba==0.56.4
numpy==1.23.5
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
Owlready2==0.40
packaging==23.0
pandas==1.5.3
pandas-stubs==2.0.1.230501
parso==0.8.3
pathspec==0.11.1
patsy==0.5.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.4.0
pkginfo==1.9.6
platformdirs==3.2.0
pluggy==1.0.0
pre-commit==3.3.2
prompt-toolkit==3.0.38
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==11.0.0
pycodestyle==2.10.0
pycparser==2.21
Pygments==2.14.0
pynndescent==0.5.8
pyparsing==3.0.9
pyproject_hooks==1.0.0
pytest==7.3.1
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
readme-renderer==37.3
requests==2.28.2
requests-mock==1.10.0
requests-toolbelt==1.0.0
rfc3986==2.0.0
rich==13.4.1
s3fs==2023.3.0
scanpy==1.9.3
scikit-learn==1.2.2
scikit-misc==0.2.0
scipy==1.10.1
seaborn==0.12.2
SecretStorage==3.3.3
session-info==1.0.0
six==1.16.0
smmap==5.0.0
somacore==1.0.3
stack-data==0.6.2
statsmodels==0.13.5
stdlib-list==0.8.0
sympy==1.12
threadpoolctl==3.1.0
tiledb==0.21.3
tiledbsoma==1.2.5
tokenize-rt==5.0.0
tomli==2.0.1
torch==2.0.1
torchdata==0.6.1
tqdm==4.65.0
traitlets==5.9.0
triton==2.0.0
twine==4.0.2
types-pytz==2023.3.0.0
typing_extensions==4.5.0
umap-learn==0.5.3
urllib3==1.26.15
virtualenv==20.21.0
wcwidth==0.2.6
webencodings==0.5.1
wrapt==1.15.0
yarl==1.8.2
zipp==3.15.0
atolopko-czi commented 1 year ago

All api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py unit tests passed on GHA for Python 3.9, Ubuntu latest (22.04).

Successfully installed MarkupSafe-2.1.3 aiobotocore-2.5.0 aiohttp-3.8.4 aioitertools-0.11.0 aiosignal-1.3.1 anndata-0.9.1 async-timeout-4.0.2 attrs-23.1.0 botocore-1.29.76 cellxgene-census-0.1.dev1+g61d9497 cmake-3.26.3 contourpy-1.0.7 cycler-0.11.0 filelock-3.12.0 fonttools-4.39.4 frozenlist-1.3.3 fsspec-2023.5.0 h5py-3.8.0 importlib-resources-5.12.0 jinja2-3.1.2 jmespath-1.0.1 joblib-1.2.0 kiwisolver-1.4.4 lit-16.0.5.post0 llvmlite-0.39.1 matplotlib-3.7.1 mpmath-1.3.0 multidict-6.0.4 natsort-8.3.1 networkx-3.1 numba-0.56.4 numpy-1.23.5 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 pandas-2.0.2 patsy-0.5.3 pillow-9.5.0 pyarrow-12.0.0 pynndescent-0.5.10 pyparsing-3.0.9 python-dateutil-2.8.2 pytz-2023.3 s3fs-2023.5.0 scanpy-1.9.3 scikit-learn-1.2.2 scikit-misc-0.2.0 scipy-1.10.1 seaborn-0.12.2 session-info-1.0.0 somacore-1.0.3 statsmodels-0.14.0 stdlib_list-0.8.0 sympy-1.12 threadpoolctl-3.1.0 tiledb-0.21.4 tiledbsoma-1.2.5 torch-2.0.1 torchdata-0.6.1 tqdm-4.65.0 triton-2.0.0 tzdata-2023.3 umap-learn-0.5.3 urllib3-1.26.16 wrapt-1.15.0 yarl-1.9.2

bkmartinjr commented 1 year ago

I wonder if it is a package version issue? I can configure a venv if you want to provide a spec, and retry.

atolopko-czi commented 1 year ago

I don't consider this a fix, but perturbing how the ExperimentDataPipe object is initialized prior to multiprocessing changes the behavior. Calling dp.shape, which transitively calls the same _init() function as obs_encoders(), allows the test to pass. Alternatively, replacing the call with iter(dp), which also invokes _init(), causes the workers to segfault later instead of hanging. In all cases, the code runs fine up until the DataLoader spawns child processes, which then attempt to use the serialized ExperimentDataPipe object. There be gremlins herein...

--- a/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
+++ b/api/python/cellxgene_census/tests/experimental/ml/test_pytorch.py
@@ -364,7 +364,7 @@ def test_experiment_dataloader__multiprocess_pickling(soma_experiment: Experimen
         obs_column_names=["label"],
     )
     dl = experiment_dataloader(dp, num_workers=2)
-    dp.obs_encoders()  # trigger query building
+    dp.shape  # trigger query building
     row = next(iter(dl))  # trigger multiprocessing
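For readers unfamiliar with the failure mode being probed here: DataLoader workers receive a pickled copy of the pipe, so any unpicklable state built lazily by `_init()` (such as an open query) must be dropped or rebuilt on deserialization. The following stdlib-only sketch is illustrative — `LazyPipe`, `_init()`, and `_handle` are hypothetical stand-ins, not the actual ExperimentDataPipe internals — showing how a `__getstate__` that discards the lazy handle keeps the object picklable after initialization:

```python
import pickle
import threading


class LazyPipe:
    """Toy stand-in for a data pipe that lazily builds unpicklable state."""

    def __init__(self, data):
        self.data = data
        self._handle = None  # built lazily, like a query object

    def _init(self):
        # Lazily create an unpicklable resource (a lock stands in here).
        if self._handle is None:
            self._handle = threading.Lock()

    @property
    def shape(self):
        self._init()  # triggers query building, as in the test above
        return (len(self.data),)

    def __getstate__(self):
        # Drop the unpicklable handle so worker processes can deserialize;
        # each worker re-runs _init() on first use.
        state = self.__dict__.copy()
        state["_handle"] = None
        return state


dp = LazyPipe([1, 2, 3])
_ = dp.shape  # _init() has now created the lock
clone = pickle.loads(pickle.dumps(dp))  # still picklable via __getstate__
print(clone.data)  # → [1, 2, 3]
```

Without the `__getstate__` override, pickling after `_init()` raises `TypeError: cannot pickle '_thread.lock' object` — the kind of serialization hazard that changes behavior depending on which method first triggered `_init()`.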
pablo-gar commented 1 year ago

see https://github.com/chanzuckerberg/cellxgene-census/issues/523#issuecomment-1581289586