jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

scipy.linalg.schur decomposition #6478

Closed ayush-1506 closed 3 years ago

ayush-1506 commented 3 years ago

It looks like Schur decomposition isn't implemented at the moment. Are there any plans to support it? Also, is there anything specific that might block me from implementing it?

shoyer commented 3 years ago

This sounds pretty reasonable to me. I don't think there are any current plans to implement this, but we would welcome contributions.

See these files for examples of how we wrap LAPACK/CuSolver for new low-level solvers: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx https://github.com/google/jax/blob/master/jaxlib/cusolver.py

(I'm guessing we don't want to implement this method ourselves)

ayush-1506 commented 3 years ago

I've implemented a wrapper for the required LAPACK routine, but I'm running into several issues building jaxlib from source to incorporate these changes. Are jaxlib wheels generated by CI/CD if I open a PR? That way I could avoid building them locally.

Tested on Ubuntu 18.04 (Docker).

jakevdp commented 3 years ago

No, the GitHub CI uses the _minimum_jaxlib_version specified in setup.py and does not cover changes to jaxlib sources. There is some testing involving jaxlib compilation outside GitHub CI once the Pull Ready tag is added, but that is unfortunately fairly opaque to non-Google employees. The best course would be to figure out how to test your changes locally.

hawkinsp commented 3 years ago

@ayush-1506 But we can probably help you figure out what's going wrong with your jaxlib build. Have you seen https://jax.readthedocs.io/en/latest/developer.html#building-from-source ?

ayush-1506 commented 3 years ago

@hawkinsp Yes, I'm following the same instructions. The error seems trivial:

Server terminated abruptly (error code: 14, error message: 'Socket closed', log file: '/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/server/jvm.out')

My guess is that this is happening due to an OOM (out-of-memory) condition.

hawkinsp commented 3 years ago

@ayush-1506 Yes, that is plausible. You might try setting the --jobs argument to Bazel; using fewer jobs than you have cores might help if you are RAM-constrained. You can pass extra Bazel options through the build.py script using --bazel_options="--jobs 4" or something like that.

ayush-1506 commented 3 years ago

@hawkinsp Thanks, reducing the number of jobs seems to work. However, I ran into another issue after that (this one doesn't seem to be related to OOM). I'm building without ROCm/CUDA/TPU.

[0 / 17] [Prepa] BazelWorkspaceStatusAction stable-status.txt
[136 / 1,670] Compiling com_google_protobuf/src/google/protobuf/compiler/java/java_message_lite.cc; 1s local ... (2 actions, 1 running)
[348 / 1,871] Compiling com_google_absl/absl/time/internal/cctz/src/time_zone_info.cc; 1s local ... (2 actions, 1 running)
[534 / 1,871] Compiling org_tensorflow/tensorflow/core/platform/hash.cc; 1s local ... (2 actions, 1 running)
[805 / 1,871] Compiling external/org_tensorflow/tensorflow/core/framework/full_type.pb.cc; 0s local ... (2 actions, 1 running)
[933 / 1,871] Compiling org_tensorflow/tensorflow/core/platform/default/posix_file_system.cc; 2s local ... (2 actions, 1 running)
ERROR: /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/org_tensorflow/tensorflow/core/BUILD:1593:16: C++ compilation of rule '@org_tensorflow//tensorflow/core:framework_internal_impl' failed (Exit 1): gcc failed: error executing command
  (cd /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/execroot/__main__ && \
  exec env - \
    PATH=/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 \
    TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 \
  /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections -fdata-sections '-std=c++11' -MD -MF bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/core/_objs/framework_internal_impl/batch_util.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/core/_objs/framework_internal_impl/batch_util.pic.o' -fPIC -DHAVE_SYS_UIO_H -DTF_USE_SNAPPY -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -iquoteexternal/org_tensorflow -iquotebazel-out/k8-opt/bin/external/org_tensorflow -iquoteexternal/com_google_protobuf -iquotebazel-out/k8-opt/bin/external/com_google_protobuf -iquoteexternal/zlib -iquotebazel-out/k8-opt/bin/external/zlib -iquoteexternal/eigen_archive -iquotebazel-out/k8-opt/bin/external/eigen_archive -iquoteexternal/com_google_absl -iquotebazel-out/k8-opt/bin/external/com_google_absl -iquoteexternal/nsync -iquotebazel-out/k8-opt/bin/external/nsync -iquoteexternal/gif -iquotebazel-out/k8-opt/bin/external/gif -iquoteexternal/libjpeg_turbo -iquotebazel-out/k8-opt/bin/external/libjpeg_turbo -iquoteexternal/com_googlesource_code_re2 -iquotebazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquoteexternal/farmhash_archive -iquotebazel-out/k8-opt/bin/external/farmhash_archive -iquoteexternal/fft2d -iquotebazel-out/k8-opt/bin/external/fft2d -iquoteexternal/highwayhash -iquotebazel-out/k8-opt/bin/external/highwayhash -iquoteexternal/double_conversion -iquotebazel-out/k8-opt/bin/external/double_conversion -iquoteexternal/snappy -iquotebazel-out/k8-opt/bin/external/snappy -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/org_tensorflow/third_party/eigen3/mkl_include -isystem bazel-out/k8-opt/bin/external/org_tensorflow/third_party/eigen3/mkl_include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -Wno-sign-compare -Wno-stringop-truncation -mavx '-std=c++14' -DEIGEN_AVOID_STL_ARRAY -Iexternal/gemmlowp -Wno-sign-compare '-ftemplate-depth=900' -fno-exceptions -DINTEL_MKL -msse3 -DTENSORFLOW_MONOLITHIC_BUILD -pthread '-DINTEL_MKL=1' -fno-canonical-system-headers -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -c external/org_tensorflow/tensorflow/core/util/batch_util.cc -o bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/core/_objs/framework_internal_impl/batch_util.pic.o)
Execution platform: @local_execution_config_platform//:platform
In file included from external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint:46,
                 from external/org_tensorflow/tensorflow/core/framework/numeric_types.h:24,
                 from external/org_tensorflow/tensorflow/core/framework/allocator.h:26,
                 from external/org_tensorflow/tensorflow/core/framework/tensor.h:23,
                 from external/org_tensorflow/tensorflow/core/util/batch_util.h:18,
                 from external/org_tensorflow/tensorflow/core/util/batch_util.cc:16:
external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h:14:41: warning: ignoring attributes on template argument '__m256i' {aka '__vector(4) long long int'} [-Wignored-attributes]
 typedef eigen_packet_wrapper<__m256i, 10> Packet32q8i;
                                         ^
external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h:15:41: warning: ignoring attributes on template argument '__m128i' {aka '__vector(2) long long int'} [-Wignored-attributes]
 typedef eigen_packet_wrapper<__m128i, 11> Packet16q8i;
                                         ^
gcc: fatal error: Killed signal terminated program cc1plus
compilation terminated.
Target //build:build_wheel failed to build
INFO: Elapsed time: 1050.784s, Critical Path: 92.65s
INFO: 974 processes: 43 internal, 931 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
Traceback (most recent call last):
  File "build/build.py", line 521, in <module>
    main()
  File "build/build.py", line 516, in main
    shell(command)
  File "build/build.py", line 51, in shell
    output = subprocess.check_output(cmd)
  File "/usr/local/lib/python3.8/subprocess.py", line 411, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/local/lib/python3.8/subprocess.py", line 512, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-3.7.2-linux-x86_64', 'run', '--verbose_failures=true', '--jobs=2', '--config=short_logs', '--config=avx_posix', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/root/jax/dist']' returned non-zero exit status 1.

ayush-1506 commented 3 years ago

Wait, I think I figured it out; the wheels were generated successfully. However, pip3 install dist/jaxlib-0.1.66-cp36-none-manylinux2010_x86_64.whl now fails with "jaxlib-0.1.66-cp36-none-manylinux2010_x86_64.whl is not a supported wheel on this platform." I'm on x86-64 with Python 3.6.9.

Why could this be happening? @hawkinsp

ayush-1506 commented 3 years ago

Ok, never mind that. Everything works now.

ayush-1506 commented 3 years ago

I'm nearly done with the LAPACK wrapper (I was away for a couple of weeks). Is CustomCallWithLayout documented anywhere? I can't find any docs for it. I'm trying to understand what its arguments mean; in particular, I can't wrap my head around which shape shape_with_layout refers to.

@jakevdp @hawkinsp

jakevdp commented 3 years ago

I don't think there is any documentation for CustomCallWithLayout; your best bet is probably to take a look at examples within the JAX codebase.

As for shape_with_layout, it specifies the shape of the output(s). It's essentially a set of abstract shapes that specify the dtype, the dimensions, and the memory layout. The layout is a minor-to-major ordering of the dimensions, so for a 2D array, row-major layout is (1, 0) and column-major layout is (0, 1).
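To make the minor-to-major convention concrete, here is a small NumPy sketch (not part of the XLA or JAX API) that derives byte strides from a layout tuple and compares them with NumPy's C-order (row-major) and Fortran-order (column-major) arrays. The helper name strides_from_minor_to_major is hypothetical and exists only for illustration.

```python
import numpy as np


def strides_from_minor_to_major(shape, minor_to_major, itemsize):
    """Byte strides implied by an XLA-style minor-to-major layout.

    Hypothetical helper for illustration: minor_to_major lists dimensions
    from fastest-varying (most minor) to slowest-varying (most major).
    """
    strides = [0] * len(shape)
    step = itemsize
    for dim in minor_to_major:
        strides[dim] = step  # elements along `dim` are `step` bytes apart
        step *= shape[dim]
    return tuple(strides)


shape = (3, 4)
row_major = np.zeros(shape, dtype=np.float64, order="C")
col_major = np.zeros(shape, dtype=np.float64, order="F")

print(strides_from_minor_to_major(shape, (1, 0), 8), row_major.strides)  # (32, 8) (32, 8)
print(strides_from_minor_to_major(shape, (0, 1), 8), col_major.strides)  # (8, 24) (8, 24)
```

LAPACK routines such as gees operate on column-major (Fortran-order) matrices, so a column-major layout like (0, 1) is presumably what a LAPACK-backed custom call would request for its 2D operands and results.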

Hope that helps

SaturdayGenfo commented 3 years ago

Hi everyone,

I landed on this issue while trying to implement the logm matrix function for https://github.com/google/jax/issues/5469.

I didn't know whether it was still active, so I went ahead and wrote a wrapper for LAPACK's gees routine (which computes the Schur decomposition). It uses C++ templates, like the refactored LAPACK wrappers in jaxlib/lapack_kernels.cc.

With these additions, jaxlib compiles on my macOS machine. I wrote a schur primitive (closely following the eig primitive code) that uses the wrapper.

In preliminary tests, it produces the same decomposition as SciPy's schur function, so it appears to work.

I'm currently writing more tests for this decomposition. The current version of my code passes the existing linalg tests, so it looks like nothing is broken. I thought I would reach out before writing the full test suite to get some early feedback.
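Because a Schur decomposition is not unique (the Schur vectors are only fixed up to phases, and the eigenvalue order on the diagonal can differ between implementations), tests that compare factors elementwise against SciPy can be brittle. A more robust sketch is to check the defining properties directly: A = Z T Z^H, Z unitary, T upper triangular. The example below uses SciPy's schur with output="complex" so that T is strictly upper triangular (the real form is only quasi-triangular, with 2x2 blocks). check_schur is a hypothetical helper; in a JAX test the factors would come from the new primitive instead.

```python
import numpy as np
from scipy.linalg import schur


def check_schur(a, t, z, atol=1e-10):
    """Property checks that hold for any valid complex Schur decomposition."""
    n = a.shape[0]
    reconstructs = np.allclose(z @ t @ z.conj().T, a, atol=atol)   # A = Z T Z^H
    unitary = np.allclose(z.conj().T @ z, np.eye(n), atol=atol)     # Z^H Z = I
    upper_triangular = np.allclose(np.tril(t, k=-1), 0.0, atol=atol)
    return reconstructs and unitary and upper_triangular


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 5))
t, z = schur(a, output="complex")
print(check_schur(a, t, z))  # True
```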

The signatures for the wrappers are here (for reference, the LAPACK signatures I followed are here). The implementation of the wrappers is here. The tie-in to SciPy's LAPACK is here. The Python gees function is here. The primitive is here.

I used function pointers for gees's optional callable argument (the eigenvalue selector used for sorting). Is there a way to pass a callable through CustomCallWithLayout? Since I didn't know how to refer to a callable in the operands argument, passing the callable argument raises an error for now.

oleg-kachan commented 2 years ago

@SaturdayGenfo, have you made any progress on the logm implementation?

lenarttreven commented 2 years ago

@SaturdayGenfo are you planning to add the sort_eig_vals option at some point? It would be very useful for solving continuous- or discrete-time Riccati equations with JAX.
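For context on why sorting matters here: the classical Schur method for the continuous-time algebraic Riccati equation A^T X + X A - X B R^{-1} B^T X + Q = 0 (Laub's method) needs the stable eigenvalues of a Hamiltonian matrix moved to the top-left of the Schur form. Below is a minimal SciPy sketch, assuming (A, B) is stabilizable, (Q, A) is detectable and R is positive definite. care_via_ordered_schur is a hypothetical name, and SciPy's solve_continuous_are is used only as a cross-check.

```python
import numpy as np
from scipy.linalg import schur, solve_continuous_are


def care_via_ordered_schur(a, b, q, r):
    """Stabilizing solution of A^T X + X A - X B R^{-1} B^T X + Q = 0."""
    n = a.shape[0]
    g = b @ np.linalg.solve(r, b.T)
    h = np.block([[a, -g], [-q, -a.T]])  # Hamiltonian matrix
    # Sorting moves the n stable (left half-plane) eigenvalues to the top-left,
    # so the first n Schur vectors span the stable invariant subspace.
    _, z, sdim = schur(h, sort="lhp")
    if sdim != n:
        raise ValueError("expected exactly n stable eigenvalues")
    u11, u21 = z[:n, :n], z[n:, :n]
    x = np.linalg.solve(u11.T, u21.T).T  # X = U21 @ inv(U11)
    return (x + x.T) / 2  # symmetrize away rounding noise


rng = np.random.default_rng(0)
n, m = 4, 2
a = rng.standard_normal((n, n))
b = rng.standard_normal((n, m))
q, r = np.eye(n), np.eye(m)
x = care_via_ordered_schur(a, b, q, r)
print(np.allclose(x, solve_continuous_are(a, b, q, r)))
```

Without the sort option, the first n Schur vectors would span an arbitrary invariant subspace and U21 U11^{-1} would generally not be the stabilizing solution.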

SaturdayGenfo commented 2 years ago

@oleg-kachan Sadly no, but some of the pieces needed to write the inverse scaling and squaring algorithm are there. Personally, I don't plan on working on logm in the near future, so anyone else is welcome to take a stab at it. One route is to write a custom call to Eigen if you want to sidestep writing a JAX implementation.

@lenarttreven There is a small obstacle to adding the sorting option: I don't know how to pass a callable through an XLA custom call (see this specific line). One option might be to hard-code common sorting functions and pass an index to refer to them. I'll try to take a look, but it's unlikely to happen in the near future.