Closed ayush-1506 closed 3 years ago
This sounds pretty reasonable to me. I don't think there are any current plans to implement this, but we would welcome contributions.
See these files for examples of how we wrap LAPACK/CuSolver for new low-level solvers: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx https://github.com/google/jax/blob/master/jaxlib/cusolver.py
(I'm guessing we don't want to implement this method ourselves)
Implemented a wrapper for the required lapack routine, but I'm facing several issues while building jaxlib from source to incorporate these changes. Are jaxlib wheels generated by the CI/CD if I open a PR? That way I can avoid generating them locally.
Tested in ubuntu 18.04 (docker).
No, the GitHub CI uses the `_minimum_jaxlib_version` specified in `setup.py` and does not cover changes to jaxlib sources. There is some testing involving jaxlib compilation outside GitHub CI once the `Pull Ready` tag is added, but that is unfortunately fairly opaque to non-Google employees. The best course would be to figure out how to test your changes locally.
@ayush-1506 But we can probably help you figure out what's going wrong with your jaxlib build. Have you seen https://jax.readthedocs.io/en/latest/developer.html#building-from-source ?
@hawkinsp Yes, I'm following the same instructions. The error seems trivial:
Server terminated abruptly (error code: 14, error message: 'Socket closed', log file: '/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/server/jvm.out')
My guess is that this is happening because the build is running out of memory (OOM).
@ayush-1506 Yes, that is plausible. You might try setting the `--jobs` argument to Bazel; using fewer jobs than you have cores might help if you are RAM constrained. You can pass extra Bazel options through the `build.py` script using `--bazel_options="--jobs 4"` or something like that.
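As a rough illustration of the suggestion above, here is a minimal Python sketch that caps the job count by both core count and RAM and then assembles the `build.py` invocation. The helper names and the 2 GB-per-job budget are assumptions made for illustration, not measured figures; the `--jobs=N` spelling follows the Bazel command that appears in the traceback later in this thread.

```python
import os


def suggest_jobs(total_ram_gb, cores, gb_per_job=2.0):
    """Pick a Bazel job count bounded by both cores and RAM.

    gb_per_job is an assumed rough budget per compile action,
    not a measured figure; tune it for your machine.
    """
    by_memory = int(total_ram_gb // gb_per_job)
    return max(1, min(cores, by_memory))


def build_command(jobs):
    # build/build.py and --bazel_options are the script and flag
    # named in this thread; the helper itself is hypothetical.
    return ["python", "build/build.py", f"--bazel_options=--jobs={jobs}"]


if __name__ == "__main__":
    cores = os.cpu_count() or 1
    # 8 GB is an example value; substitute the RAM of your container.
    print(build_command(suggest_jobs(total_ram_gb=8, cores=cores)))
```

If the build still gets killed, lowering `gb_per_job`'s counterpart (that is, raising the per-job budget so fewer jobs run) is the knob to try first.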
@hawkinsp Thanks, reducing the number of jobs seems to work. However, I ran into another issue after this. (This doesn't seem to be related to OOM.) I'm building this without ROCm/CUDA/TPU support.
[0 / 17] [Prepa] BazelWorkspaceStatusAction stable-status.txt
[136 / 1,670] Compiling com_google_protobuf/src/google/protobuf/compiler/java/java_message_lite.cc; 1s local ... (2 actions, 1 running)
[348 / 1,871] Compiling com_google_absl/absl/time/internal/cctz/src/time_zone_info.cc; 1s local ... (2 actions, 1 running)
[534 / 1,871] Compiling org_tensorflow/tensorflow/core/platform/hash.cc; 1s local ... (2 actions, 1 running)
[805 / 1,871] Compiling external/org_tensorflow/tensorflow/core/framework/full_type.pb.cc; 0s local ... (2 actions, 1 running)
[933 / 1,871] Compiling org_tensorflow/tensorflow/core/platform/default/posix_file_system.cc; 2s local ... (2 actions, 1 running)
ERROR: /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/org_tensorflow/tensorflow/core/BUILD:1593:16: C++ compilation of rule '@org_tensorflow//tensorflow/core:framework_internal_impl' failed (Exit 1): gcc failed: error executing command
(cd /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/execroot/__main__ && \
exec env - \
PATH=/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
PWD=/proc/self/cwd \
TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 \
TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 \
/usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections -fdata-sections '-std=c++11' -MD -MF bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/core/_objs/framework_internal_impl/batch_util.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/core/_objs/framework_internal_impl/batch_util.pic.o' -fPIC -DHAVE_SYS_UIO_H -DTF_USE_SNAPPY -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -iquoteexternal/org_tensorflow -iquotebazel-out/k8-opt/bin/external/org_tensorflow -iquoteexternal/com_google_protobuf -iquotebazel-out/k8-opt/bin/external/com_google_protobuf -iquoteexternal/zlib -iquotebazel-out/k8-opt/bin/external/zlib -iquoteexternal/eigen_archive -iquotebazel-out/k8-opt/bin/external/eigen_archive -iquoteexternal/com_google_absl -iquotebazel-out/k8-opt/bin/external/com_google_absl -iquoteexternal/nsync -iquotebazel-out/k8-opt/bin/external/nsync -iquoteexternal/gif -iquotebazel-out/k8-opt/bin/external/gif -iquoteexternal/libjpeg_turbo -iquotebazel-out/k8-opt/bin/external/libjpeg_turbo -iquoteexternal/com_googlesource_code_re2 -iquotebazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquoteexternal/farmhash_archive -iquotebazel-out/k8-opt/bin/external/farmhash_archive -iquoteexternal/fft2d -iquotebazel-out/k8-opt/bin/external/fft2d -iquoteexternal/highwayhash -iquotebazel-out/k8-opt/bin/external/highwayhash -iquoteexternal/double_conversion -iquotebazel-out/k8-opt/bin/external/double_conversion -iquoteexternal/snappy -iquotebazel-out/k8-opt/bin/external/snappy -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/org_tensorflow/third_party/eigen3/mkl_include -isystem bazel-out/k8-opt/bin/external/org_tensorflow/third_party/eigen3/mkl_include -isystem external/eigen_archive 
-isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -Wno-sign-compare -Wno-stringop-truncation -mavx '-std=c++14' -DEIGEN_AVOID_STL_ARRAY -Iexternal/gemmlowp -Wno-sign-compare '-ftemplate-depth=900' -fno-exceptions -DINTEL_MKL -msse3 -DTENSORFLOW_MONOLITHIC_BUILD -pthread '-DINTEL_MKL=1' -fno-canonical-system-headers -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -c external/org_tensorflow/tensorflow/core/util/batch_util.cc -o bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/core/_objs/framework_internal_impl/batch_util.pic.o)
Execution platform: @local_execution_config_platform//:platform
In file included from external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint:46,
from external/org_tensorflow/tensorflow/core/framework/numeric_types.h:24,
from external/org_tensorflow/tensorflow/core/framework/allocator.h:26,
from external/org_tensorflow/tensorflow/core/framework/tensor.h:23,
from external/org_tensorflow/tensorflow/core/util/batch_util.h:18,
from external/org_tensorflow/tensorflow/core/util/batch_util.cc:16:
external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h:14:41: warning: ignoring attributes on template argument '__m256i' {aka '__vector(4) long long int'} [-Wignored-attributes]
typedef eigen_packet_wrapper<__m256i, 10> Packet32q8i;
^
external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h:15:41: warning: ignoring attributes on template argument '__m128i' {aka '__vector(2) long long int'} [-Wignored-attributes]
typedef eigen_packet_wrapper<__m128i, 11> Packet16q8i;
^
gcc: fatal error: Killed signal terminated program cc1plus
compilation terminated.
Target //build:build_wheel failed to build
INFO: Elapsed time: 1050.784s, Critical Path: 92.65s
INFO: 974 processes: 43 internal, 931 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
Traceback (most recent call last):
File "build/build.py", line 521, in <module>
main()
File "build/build.py", line 516, in main
shell(command)
File "build/build.py", line 51, in shell
output = subprocess.check_output(cmd)
File "/usr/local/lib/python3.8/subprocess.py", line 411, in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
File "/usr/local/lib/python3.8/subprocess.py", line 512, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-3.7.2-linux-x86_64', 'run', '--verbose_failures=true', '--jobs=2', '--config=short_logs', '--config=avx_posix', '--config=mkl_open_source_only', ':build_wheel', '--', '--output_path=/root/jax/dist']' returned non-zero exit status 1.
Wait, I think I figured it out. Generated the wheels successfully. However, `pip3 install dist/jaxlib-0.1.66-cp36-none-manylinux2010_x86_64.whl` now gives me `jaxlib-0.1.66-cp36-none-manylinux2010_x86_64.whl is not a supported wheel on this platform.`
I'm on x86-64 with Python 3.6.9.
Why could this be happening? @hawkinsp
Ok, never mind that. Everything works now.
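The thread does not say what caused the "not a supported wheel" message. For anyone debugging a similar one, a first step is to compare the tags embedded in the wheel filename with the interpreter that `pip3` actually runs under. The sketch below uses only the standard library; the helper names are hypothetical, and it assumes the standard wheel filename layout (name, version, optional build tag, then Python, ABI, and platform tags) and a CPython interpreter.

```python
import sys


def parse_wheel_tags(filename):
    """Split a wheel filename into its name, version, and compatibility tags."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel filename: " + filename)
    parts = filename[: -len(".whl")].split("-")
    # An optional build tag makes this 5 or 6 components;
    # the compatibility tags are always the last three.
    if len(parts) not in (5, 6):
        raise ValueError("unexpected wheel filename: " + filename)
    python_tag, abi_tag, platform_tag = parts[-3:]
    return {
        "name": parts[0],
        "version": parts[1],
        "python": python_tag,
        "abi": abi_tag,
        "platform": platform_tag,
    }


def running_python_tag():
    # CPython-style tag for the current interpreter, e.g. cp38 for CPython 3.8.
    return "cp{}{}".format(sys.version_info.major, sys.version_info.minor)


if __name__ == "__main__":
    tags = parse_wheel_tags("jaxlib-0.1.66-cp36-none-manylinux2010_x86_64.whl")
    print(tags)
    print("interpreter tag:", running_python_tag())
```

A matching Python tag does not by itself guarantee the wheel installs, since the installer also has to recognize the platform tag, but a mismatch here is a quick thing to rule out.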
I'm nearly done with the LAPACK wrapper (was away for a couple of weeks). Is `CustomCallWithLayout` documented somewhere? I can't find any docs for it. I'm trying to understand what its arguments mean, and I can't wrap my head around what shape `shape_with_layout` refers to.
@jakevdp @hawkinsp
I don't think there is any documentation for `CustomCallWithLayout`; your best bet is probably to take a look at examples within the JAX codebase.
As for `shape_with_layout`, this specifies the shape of the outputs. It's essentially a set of abstract shapes that specify the dtype, the dimensions, and the memory layout. The layout lists dimensions from minor (fastest-varying) to major, so for a 2D array, row-major layout is `(1, 0)` and column-major layout is `(0, 1)`.
Hope that helps
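To make the layout tuples concrete, here is a small pure-Python sketch. It assumes the layout lists dimensions from minor to major (fastest-varying first), which is how XLA describes layouts; `linear_offset` is a hypothetical helper written for illustration and is not part of JAX or XLA.

```python
def linear_offset(index, dims, minor_to_major):
    """Flat memory offset of `index` in an array of shape `dims`.

    `minor_to_major` lists dimension numbers starting with the one
    that varies fastest in memory (assumed XLA-style convention).
    """
    offset = 0
    stride = 1
    for d in minor_to_major:
        offset += index[d] * stride
        stride *= dims[d]
    return offset


if __name__ == "__main__":
    dims = (2, 3)  # 2 rows, 3 columns
    # Element at row 0, column 1:
    print(linear_offset((0, 1), dims, (1, 0)))  # row-major: columns vary fastest
    print(linear_offset((0, 1), dims, (0, 1)))  # column-major: rows vary fastest
```

LAPACK routines expect column-major (Fortran-order) matrices, which is why a wrapper for a LAPACK routine would request a layout in which the row dimension is the most minor one.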
Hi everyone,
I landed on this issue while trying to implement the `logm` matrix function for https://github.com/google/jax/issues/5469.
I didn't know if it was still active, so I went ahead and wrote a wrapper for LAPACK's `gees` routine (which computes the Schur decomposition). It uses C++ templates like the refactored LAPACK wrappers in `jaxlib/lapack_kernels.cc`.
With these additions, on my macOS machine, jaxlib compiles. I wrote a schur primitive (by closely following the eig primitive code) to use the wrapper.
On preliminary tests it outputs the same decomposition as scipy's schur function, so it appears to work.
I am currently working on writing more tests for this decomposition. The current version of my code passes the existing linalg tests so it looks like nothing has broken. I thought I would write before writing the test suite to maybe get some early feedback.
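For reference, the identities that tests of a Schur decomposition would typically assert (up to floating-point tolerance) can be written as follows. This is the standard definition, stated here as a reminder rather than as a description of the test suite mentioned above.

```latex
A = Z\,T\,Z^{H}, \qquad Z^{H} Z = I, \qquad T_{ij} = 0 \ \text{for } i > j
```

Here $Z$ is unitary and $T$ is upper triangular with the eigenvalues of $A$ on its diagonal. In the real form, $Z$ is orthogonal ($A = Z\,T\,Z^{T}$) and $T$ is only quasi-upper-triangular: it may contain $2 \times 2$ diagonal blocks corresponding to complex-conjugate eigenvalue pairs.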
The signatures for the wrappers are here. (For reference, the LAPACK signatures I followed are here.) The implementation of the wrappers is here. The tie-in to scipy's LAPACK is here. The Python gees function is here. The primitive is here.
I used function pointers for the optional callable argument that `gees` uses. Is there a way to pass a callable through `CustomCallWithLayout`? Since I didn't know how to refer to a callable in the `operands` argument, using the callable argument raises an error for now.
@SaturdayGenfo, have you resolved the issue with the `logm` implementation?
@SaturdayGenfo, are you planning to add the sort_eig_vals option at some point? It would be very useful when solving Riccati equations in continuous or discrete time with JAX.
@oleg-kachan No, sadly, but some of the elements needed to write the inverse scaling and squaring algorithm are there. Personally, I don't plan on working on `logm` in the very near future (so anyone else is welcome to take a stab at it; one route is to write a custom call to Eigen if you want to sidestep having to write a JAX implementation).
@lenarttreven There is a small obstacle to adding the sorting option: I don't know how to pass a callable function through an XLA custom call (see this specific line). Maybe one option is to have common sorting functions hard-coded and pass an index to refer to them. I'll try to take a look, but it's unlikely to be in the very near future.
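The "hard-coded functions plus an index" idea could look roughly like the Python sketch below. Everything in it is hypothetical: the table, the selector names, and the helpers are illustrations of the dispatch pattern only, not code from the wrapper discussed in this thread. The point is that an integer is plain data, so it could presumably be passed like any other scalar operand, while the native side keeps a matching table of function pointers.

```python
# Hypothetical table of hard-coded eigenvalue selectors.
# Names and ordering are illustrative; a real wrapper would need the
# same table, in the same order, on the native side.
SELECTORS = (
    ("lhp", lambda ev: ev.real < 0.0),   # left half-plane
    ("rhp", lambda ev: ev.real > 0.0),   # right half-plane
    ("iuc", lambda ev: abs(ev) <= 1.0),  # inside (or on) the unit circle
    ("ouc", lambda ev: abs(ev) > 1.0),   # outside the unit circle
)


def selector_index(name):
    """Translate a selector name into the integer that would be passed along."""
    for i, (known, _) in enumerate(SELECTORS):
        if known == name:
            return i
    raise ValueError("unknown selector: " + name)


def select(index, eigenvalue):
    """Apply the selector stored at `index` to one eigenvalue."""
    return SELECTORS[index][1](complex(eigenvalue))


if __name__ == "__main__":
    idx = selector_index("lhp")
    print(idx, select(idx, complex(-1.0, 2.0)))
```

For the Riccati use case mentioned above, the left-half-plane selector would correspond to the continuous-time problem and the inside-unit-circle selector to the discrete-time one.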
It looks like the Schur decomposition isn't implemented at the moment. Are there any plans to support this? Also, is there anything specific that might block me from implementing it?