google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Segfault with `jax.linear_transpose` #16027

Closed patrick-kidger closed 1 year ago

patrick-kidger commented 1 year ago

Description

See #15998: before its change to jax.linear_transpose, its test case generated a segfault (at least on my machine, on the CPU). This probably indicates another bug somewhere.

What jax/jaxlib version are you using?

4.10

Which accelerator(s) are you using?

CPU

Additional system info

No response

NVIDIA GPU info

No response

pschuh commented 1 year ago

I can't reproduce the segfault. When I comment out your fix it just throws an exception.

NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'gt' not implemented

patrick-kidger commented 1 year ago

Hmm, I'm able to get this reliably on two different machines.

Repro:

❯ git clone https://github.com/google/jax --depth=1
❯ cd jax
❯ # edit jax/_src/api.py to remove my fix
❯ pip install -e .[cpu]             
❯ cd tests/
❯ pip install pytest       
❯ pip install absl-py                
❯ pytest api_test.py -k transpose_dce
=========================================================================== test session starts ============================================================================
platform linux -- Python 3.11.3, pytest-7.3.1, pluggy-1.0.0
rootdir: .../projects/jax
configfile: pyproject.toml
collected 768 items / 767 deselected / 1 selected                                                                                                                          

api_test.py Fatal Python error: Segmentation fault

Current thread 0x00007f6c42c49500 (most recent call first):
  File ".../python3.11/site-packages/_pytest/_code/code.py", line 151 in f_locals
  File ".../python3.11/site-packages/_pytest/_code/code.py", line 242 in locals
  File ".../python3.11/site-packages/_pytest/_code/code.py", line 833 in repr_traceback_entry
  File ".../python3.11/site-packages/_pytest/_code/code.py", line 873 in repr_traceback
  File ".../python3.11/site-packages/_pytest/_code/code.py", line 946 in repr_excinfo
  File ".../python3.11/site-packages/_pytest/_code/code.py", line 669 in getrepr
  File ".../python3.11/site-packages/_pytest/nodes.py", line 484 in _repr_failure_py
  File ".../python3.11/site-packages/_pytest/python.py", line 1833 in repr_failure
  File ".../python3.11/site-packages/_pytest/reports.py", line 359 in from_item_and_call
  File ".../python3.11/site-packages/_pytest/runner.py", line 368 in pytest_runtest_makereport
  File ".../python3.11/site-packages/pluggy/_callers.py", line 39 in _multicall
  File ".../python3.11/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File ".../python3.11/site-packages/pluggy/_hooks.py", line 265 in __call__
  File ".../python3.11/site-packages/_pytest/runner.py", line 224 in call_and_report
  File ".../python3.11/site-packages/_pytest/runner.py", line 133 in runtestprotocol
  File ".../python3.11/site-packages/_pytest/runner.py", line 114 in pytest_runtest_protocol
  File ".../python3.11/site-packages/pluggy/_callers.py", line 39 in _multicall
  File ".../python3.11/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File ".../python3.11/site-packages/pluggy/_hooks.py", line 265 in __call__
  File ".../python3.11/site-packages/_pytest/main.py", line 348 in pytest_runtestloop
  File ".../python3.11/site-packages/pluggy/_callers.py", line 39 in _multicall
  File ".../python3.11/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File ".../python3.11/site-packages/pluggy/_hooks.py", line 265 in __call__
  File ".../python3.11/site-packages/_pytest/main.py", line 323 in _main
  File ".../python3.11/site-packages/_pytest/main.py", line 269 in wrap_session
  File ".../python3.11/site-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File ".../python3.11/site-packages/pluggy/_callers.py", line 39 in _multicall
  File ".../python3.11/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File ".../python3.11/site-packages/pluggy/_hooks.py", line 265 in __call__
  File ".../python3.11/site-packages/_pytest/config/__init__.py", line 166 in main
  File ".../python3.11/site-packages/_pytest/config/__init__.py", line 189 in console_main
  File ".../pytest", line 8 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, jaxlib.cpu_feature_guard (total: 14)
fish: Job 1, 'pytest api_test.py -k transpose…' terminated by signal SIGSEGV (Address boundary error)

pschuh commented 1 year ago

Thanks! The combination of Python 3.11 and pytest was important.

Looks like pytest is segfaulting while printing out our tracebacks, but it is the same error.

This is a minimal repro:

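from jax._src.lib import xla_client
from absl.testing import absltest
from absl.testing import parameterized

class TracebackTest(parameterized.TestCase):
  def test_basic(self):
    # Capture the current call stack as a jaxlib Traceback object.
    tb = xla_client.Traceback.get_traceback()
    # Attach the converted traceback to a new exception; the crash happens
    # later, when the test runner prints this traceback.
    raise RuntimeError("Error").with_traceback(tb.as_python_traceback())

if __name__ == '__main__':
  absltest.main()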
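
For comparison, here is a pure-Python sketch of the same pattern: build a traceback by hand, attach it with `with_traceback`, then walk it the way a failure reporter would, reading `f_locals` on each entry (the access at the top of the crash stack above). It uses genuine frame objects from the running interpreter, so it is not expected to crash and does not reproduce the segfault. The helper names `capture_traceback` and `walk_locals` are hypothetical and are not part of jaxlib or pytest.

```python
import sys
import types


def capture_traceback():
    """Build a traceback for the caller's stack out of real frame objects."""
    frame = sys._getframe(1)  # the caller's frame
    tb = None
    # Walk outward from the caller. Each new entry points at the previous one,
    # so the finished chain runs from the outermost frame down to the caller.
    while frame is not None:
        tb = types.TracebackType(tb, frame, frame.f_lasti, frame.f_lineno)
        frame = frame.f_back
    return tb


def walk_locals(tb):
    """Mimic a failure reporter: visit each entry and read its frame's locals."""
    names = []
    while tb is not None:
        _ = tb.tb_frame.f_locals  # same attribute access as in the crash stack
        names.append(tb.tb_frame.f_code.co_name)
        tb = tb.tb_next
    return names


def failing():
    # Same shape as the repro: raise with a hand-built traceback attached.
    raise RuntimeError("Error").with_traceback(capture_traceback())


try:
    failing()
except RuntimeError as exc:
    visited = walk_locals(exc.__traceback__)
```

The difference in the repro is that `tb.as_python_traceback()` supplies frames that the interpreter did not create itself, so the `f_locals` read is where things can go wrong.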

hawkinsp commented 1 year ago

I think the problem is that using real PyCodeObjects in our fake PyFrameObjects is confusing the Python interpreter. I should have a fix shortly.
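
As background for that hypothesis, here is a small sketch of how a genuine frame relates to its code object: the names that appear in `f_locals` for a function frame come from the code object's `co_varnames`. It only illustrates why a frame and its code object are expected to agree; it is not the confirmed root cause or the eventual fix. The function `sample` is a hypothetical example.

```python
import sys


def sample(a, b):
    c = a + b
    frame = sys._getframe()  # the real frame for this call
    # Return copies of the data and let the frame reference go out of scope.
    return frame.f_code.co_varnames, dict(frame.f_locals)


varnames, local_values = sample(1, 2)

# Every local reported by the frame is a variable declared by its code object.
consistent = set(local_values) <= set(varnames)
```

A reporter that reads `f_locals` relies on this agreement, which a real interpreter-created frame provides automatically.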