It's going to be hard to debug this without being able to reproduce it. (It does not reproduce for me.)
Are you able to get a backtrace and a disassembly around the faulting instruction? (i.e., start Python, open gdb, attach to the Python process with (attach <pid>), trigger the failure, and then show the output of bt and dis around the faulting instruction.)
That might give us a clue as to what went wrong.
Triggered failure:
[New Thread 0x7f9875feb700 (LWP 21508)]
Thread 18 "python3" received signal SIGILL, Illegal instruction.
[Switching to Thread 0x7f987d7fa700 (LWP 21462)]
0x00007f98ae2487ab in jit.normal ()
bt output:
(gdb) bt
#0 0x00007f98ae2487ab in jit.normal ()
#1 0x00007f989277807e in xla::cpu::CpuExecutable::ExecuteComputeFunction(xla::ExecutableRunOptions const*, absl::lts_2020_02_25::Span<stream_executor::DeviceMemoryBase const>, xla::HloExecutionProfile*) () from /auto/homes/fav25/jax/build/jaxlib/xla_extension.so
#2 0x00007f9892778c3d in xla::cpu::CpuExecutable::ExecuteAsyncOnStream(xla::ServiceExecutableRunOptions const*, std::vector<xla::ExecutionInput, std::allocator<xla::ExecutionInput> >, xla::HloExecutionProfile*)::AsyncRunTask::operator()() ()
from /auto/homes/fav25/jax/build/jaxlib/xla_extension.so
#3 0x00007f98948edfec in stream_executor::host::HostStream::WorkLoop() ()
from /auto/homes/fav25/jax/build/jaxlib/xla_extension.so
#4 0x00007f989202b6df in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#5 0x00007f98af3bb6db in start_thread (arg=0x7f987d7fa700) at pthread_create.c:463
#6 0x00007f98af6f488f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95
disassemble output on current frame (jit.normal()):
(gdb) disassemble
Dump of assembler code for function jit__normal.260:
0x00007f9414750000 <+0>: push %rbp
0x00007f9414750001 <+1>: push %r15
0x00007f9414750003 <+3>: push %r14
0x00007f9414750005 <+5>: push %r13
0x00007f9414750007 <+7>: push %r12
0x00007f9414750009 <+9>: push %rbx
0x00007f941475000a <+10>: sub $0xa8,%rsp
0x00007f9414750011 <+17>: mov 0x28(%rcx),%rax
0x00007f9414750015 <+21>: mov 0x40(%rcx),%r10
0x00007f9414750019 <+25>: movabs $0x7f941474f000,%rsi
0x00007f9414750023 <+35>: vmovups (%rsi),%xmm0
0x00007f9414750027 <+39>: lea 0x60(%r10),%r8
0x00007f941475002b <+43>: vmovups %xmm0,0x60(%r10)
0x00007f9414750031 <+49>: movabs $0x7f941474f010,%rsi
0x00007f941475003b <+59>: vmovups (%rsi),%xmm0
0x00007f941475003f <+63>: lea 0x70(%r10),%r9
0x00007f9414750043 <+67>: vmovups %xmm0,0x70(%r10)
0x00007f9414750049 <+73>: lea 0x80(%r10),%rdx
0x00007f9414750050 <+80>: mov (%rax),%ebx
0x00007f9414750052 <+82>: mov 0x4(%rax),%eax
0x00007f9414750055 <+85>: lea 0x2(%rax),%esi
0x00007f9414750058 <+88>: mov %esi,0x80(%r10)
0x00007f941475005f <+95>: mov %eax,0x84(%r10)
0x00007f9414750066 <+102>: lea 0x90(%r10),%rdi
0x00007f941475006d <+109>: mov %ebx,0x90(%r10)
0x00007f9414750074 <+116>: lea 0x1(%rbx),%esi
0x00007f9414750077 <+119>: mov %esi,0x94(%r10)
0x00007f941475007e <+126>: lea 0xe0(%r10),%rbp
0x00007f9414750085 <+133>: movl $0x0,0xe0(%r10)
0x00007f9414750090 <+144>: lea 0xa0(%r10),%r11
0x00007f9414750097 <+151>: mov %ebx,%esi
0x00007f9414750099 <+153>: xor %eax,%esi
0x00007f941475009b <+155>: xor $0x1bd11bda,%esi
0x00007f94147500a1 <+161>: mov %esi,0xa0(%r10)
0x00007f94147500a8 <+168>: lea 0xc0(%r10),%rsi
0x00007f94147500af <+175>: mov %eax,0xc0(%r10)
0x00007f94147500b6 <+182>: lea 0xd0(%r10),%rax
0x00007f94147500bd <+189>: mov %ebx,0xd0(%r10)
0x00007f94147500c4 <+196>: mov %rbp,-0x40(%rsp)
0x00007f94147500c9 <+201>: mov %rbp,(%r10)
0x00007f94147500cc <+204>: mov %rdi,0x8(%r10)
0x00007f94147500d0 <+208>: mov %rdx,-0x38(%rsp)
0x00007f94147500d5 <+213>: mov %rdx,0x10(%r10)
0x00007f94147500d9 <+217>: mov %rsi,(%rsp)
Not sure which of those frames to switch to?
Since I am not sure which frame to show the dump for, I'll go through all 5 here (frame #4 is missing because gdb threw an error about not being able to find a function at that address):
Dump of assembler code for function _ZN3xla3cpu13CpuExecutable22ExecuteComputeFunctionEPKNS_20ExecutableRunOptionsEN4absl14lts_2020_02_254SpanIKN15stream_executor16DeviceMemoryBaseEEEPNS_19HloExecutionProfileE:
0x00007fdfea14bee0 <+0>: push %rbp
0x00007fdfea14bee1 <+1>: mov %rsp,%rbp
0x00007fdfea14bee4 <+4>: push %r15
0x00007fdfea14bee6 <+6>: push %r14
0x00007fdfea14bee8 <+8>: push %r13
0x00007fdfea14beea <+10>: push %r12
0x00007fdfea14beec <+12>: mov %r9,%r14
0x00007fdfea14beef <+15>: push %rbx
0x00007fdfea14bef0 <+16>: mov %rsi,%r15
0x00007fdfea14bef3 <+19>: mov %rdx,%r13
0x00007fdfea14bef6 <+22>: mov %rcx,%rbx
0x00007fdfea14bef9 <+25>: mov %r8,%r12
0x00007fdfea14befc <+28>: sub $0x2f8,%rsp
0x00007fdfea14bf03 <+35>: mov %rdi,-0x2d8(%rbp)
0x00007fdfea14bf0a <+42>: mov %fs:0x28,%rax
0x00007fdfea14bf13 <+51>: mov %rax,-0x38(%rbp)
0x00007fdfea14bf17 <+55>: xor %eax,%eax
0x00007fdfea14bf19 <+57>: callq 0x7fdfec6f7e30 <_ZN10tensorflow3Env7DefaultEv>
0x00007fdfea14bf1e <+62>: mov (%rax),%rdx
0x00007fdfea14bf21 <+65>: mov %rax,%rdi
0x00007fdfea14bf24 <+68>: callq *0x50(%rdx)
0x00007fdfea14bf27 <+71>: test %r14,%r14
0x00007fdfea14bf2a <+74>: mov %rax,-0x2f8(%rbp)
0x00007fdfea14bf31 <+81>: je 0x7fdfea14c230 <_ZN3xla3cpu13CpuExecutable22ExecuteComputeFunctionEPKNS_20ExecutableRunOptionsEN4absl14lts_2020_02_254SpanIKN15stream_executor16DeviceMemoryBaseEEEPNS_19HloExecutionProfileE+848>
0x00007fdfea14bf37 <+87>: mov 0x10(%r14),%rax
0x00007fdfea14bf3b <+91>: mov %rax,%rcx
0x00007fdfea14bf3e <+94>: mov %rax,-0x2e0(%rbp)
0x00007fdfea14bf45 <+101>: mov 0x18(%r14),%rax
0x00007fdfea14bf49 <+105>: sub %rcx,%rax
0x00007fdfea14bf4c <+108>: sar $0x3,%rax
0x00007fdfea14bf50 <+112>: mov %rax,-0x300(%rbp)
0x00007fdfea14bf57 <+119>: lea (%r12,%r12,2),%rax
0x00007fdfea14bf5b <+123>: movq $0x0,-0x290(%rbp)
0x00007fdfea14bf66 <+134>: movq $0x0,-0x288(%rbp)
0x00007fdfea14bf71 <+145>: movq $0x0,-0x280(%rbp)
0x00007fdfea14bf7c <+156>: lea (%rbx,%rax,8),%r12
0x00007fdfea14bf80 <+160>: cmp %r12,%rbx
0x00007fdfea14bf83 <+163>: je 0x7fdfea14c010 <_ZN3xla3cpu13CpuExecutable22ExecuteComputeFunctionEPKNS_20ExecutableRunOptionsEN4absl14lts_2020_02_254SpanIKN15stream_executor16DeviceMemoryBaseEEEPNS_19HloExecutionProfileE+304>
0x00007fdfea14bf89 <+169>: lea -0x250(%rbp),%rax
0x00007fdfea14bf90 <+176>: xor %edx,%edx
0x00007fdfea14bf92 <+178>: xor %esi,%esi
Dump of assembler code for function _ZZN3xla3cpu13CpuExecutable20ExecuteAsyncOnStreamEPKNS_27ServiceExecutableRunOptionsESt6vectorINS_14ExecutionInputESaIS6_EEPNS_19HloExecutionProfileEEN12AsyncRunTaskclEv:
0x00007fdfea14cbe0 <+0>: push %rbp
0x00007fdfea14cbe1 <+1>: lea 0x8(%rdi),%rdx
0x00007fdfea14cbe5 <+5>: mov %rsp,%rbp
0x00007fdfea14cbe8 <+8>: push %r12
0x00007fdfea14cbea <+10>: push %rbx
0x00007fdfea14cbeb <+11>: lea -0x1a8(%rbp),%rbx
0x00007fdfea14cbf2 <+18>: sub $0x1a0,%rsp
0x00007fdfea14cbf9 <+25>: mov 0x80(%rdi),%rcx
0x00007fdfea14cc00 <+32>: mov 0x88(%rdi),%r8
0x00007fdfea14cc07 <+39>: mov %fs:0x28,%rax
0x00007fdfea14cc10 <+48>: mov %rax,-0x18(%rbp)
0x00007fdfea14cc14 <+52>: xor %eax,%eax
0x00007fdfea14cc16 <+54>: movabs $0xaaaaaaaaaaaaaaab,%rax
0x00007fdfea14cc20 <+64>: mov (%rdi),%rsi
0x00007fdfea14cc23 <+67>: mov 0xa8(%rdi),%r9
0x00007fdfea14cc2a <+74>: sub %rcx,%r8
0x00007fdfea14cc2d <+77>: mov %rbx,%rdi
0x00007fdfea14cc30 <+80>: sar $0x3,%r8
0x00007fdfea14cc34 <+84>: imul %rax,%r8
0x00007fdfea14cc38 <+88>: callq 0x7fdfea14bee0 <_ZN3xla3cpu13CpuExecutable22ExecuteComputeFunctionEPKNS_20ExecutableRunOptionsEN4absl14lts_2020_02_254SpanIKN15stream_executor16DeviceMemoryBaseEEEPNS_19HloExecutionProfileE>
0x00007fdfea14cc3d <+93>: cmpq $0x0,-0x1a8(%rbp)
0x00007fdfea14cc45 <+101>: je 0x7fdfea14cc6f <_ZZN3xla3cpu13CpuExecutable20ExecuteAsyncOnStreamEPKNS_27ServiceExecutableRunOptionsESt6vectorINS_14ExecutionInputESaIS6_EEPNS_19HloExecutionProfileEEN12AsyncRunTaskclEv+143>
0x00007fdfea14cc47 <+103>: lea 0x28c248a(%rip),%rsi # 0x7fdfeca0f0d8
0x00007fdfea14cc4e <+110>: mov %rbx,%rdi
0x00007fdfea14cc51 <+113>: callq 0x7fdfec7403c0 <_ZN10tensorflow24TfCheckOpHelperOutOfLineB5cxx11ERKNS_6StatusEPKc>
0x00007fdfea14cc56 <+118>: mov -0x1a8(%rbp),%rdi
0x00007fdfea14cc5d <+125>: mov %rax,%rbx
0x00007fdfea14cc60 <+128>: test %rdi,%rdi
0x00007fdfea14cc63 <+131>: je 0x7fdfea14cc6a <_ZZN3xla3cpu13CpuExecutable20ExecuteAsyncOnStreamEPKNS_27ServiceExecutableRunOptionsESt6vectorINS_14ExecutionInputESaIS6_EEPNS_19HloExecutionProfileEEN12AsyncRunTaskclEv+138>
0x00007fdfea14cc65 <+133>: callq 0x7fdfea147680 <_ZNKSt14default_deleteIN10tensorflow6Status5StateEEclEPS2_.isra.234>
0x00007fdfea14cc6a <+138>: test %rbx,%rbx
0x00007fdfea14cc6d <+141>: jne 0x7fdfea14cc8f <_ZZN3xla3cpu13CpuExecutable20ExecuteAsyncOnStreamEPKNS_27ServiceExecutableRunOptionsESt6vectorINS_14ExecutionInputESaIS6_EEPNS_19HloExecutionProfileEEN12AsyncRunTaskclEv+175>
0x00007fdfea14cc6f <+143>: mov -0x18(%rbp),%rax
0x00007fdfea14cc73 <+147>: xor %fs:0x28,%rax
0x00007fdfea14cc7c <+156>: jne 0x7fdfea14cc8a <_ZZN3xla3cpu13CpuExecutable20ExecuteAsyncOnStreamEPKNS_27ServiceExecutableRunOptionsESt6vectorINS_14ExecutionInputESaIS6_EEPNS_19HloExecutionProfileEEN12AsyncRunTaskclEv+170>
0x00007fdfea14cc7e <+158>: add $0x1a0,%rsp
0x00007fdfea14cc85 <+165>: pop %rbx
0x00007fdfea14cc86 <+166>: pop %r12
disassemble 0x00007fdfec2c1fec
Dump of assembler code for function _ZN15stream_executor4host10HostStream8WorkLoopEv:
0x00007fdfec2c1e90 <+0>: push %rbp
0x00007fdfec2c1e91 <+1>: mov %rsp,%rbp
0x00007fdfec2c1e94 <+4>: push %r15
0x00007fdfec2c1e96 <+6>: push %r14
0x00007fdfec2c1e98 <+8>: push %r13
0x00007fdfec2c1e9a <+10>: push %r12
0x00007fdfec2c1e9c <+12>: push %rbx
0x00007fdfec2c1e9d <+13>: mov %rdi,%rbx
0x00007fdfec2c1ea0 <+16>: sub $0xa8,%rsp
0x00007fdfec2c1ea7 <+23>: mov %fs:0x28,%rax
0x00007fdfec2c1eb0 <+32>: mov %rax,-0x38(%rbp)
0x00007fdfec2c1eb4 <+36>: xor %eax,%eax
0x00007fdfec2c1eb6 <+38>: lea -0xb6(%rbp),%rax
0x00007fdfec2c1ebd <+45>: mov %rax,%rdi
0x00007fdfec2c1ec0 <+48>: mov %rax,-0xc8(%rbp)
0x00007fdfec2c1ec7 <+55>: callq 0x7fdfec705010 <_ZN10tensorflow4port19ScopedFlushDenormalC2Ev>
0x00007fdfec2c1ecc <+60>: lea -0xb4(%rbp),%rax
0x00007fdfec2c1ed3 <+67>: xor %esi,%esi
0x00007fdfec2c1ed5 <+69>: mov %rax,%rdi
0x00007fdfec2c1ed8 <+72>: mov %rax,-0xd0(%rbp)
0x00007fdfec2c1edf <+79>: callq 0x7fdfec73e4a0 <_ZN10tensorflow4port14ScopedSetRoundC2Ei>
0x00007fdfec2c1ee4 <+84>: lea 0x8(%rbx),%r12
0x00007fdfec2c1ee8 <+88>: lea -0xb0(%rbp),%r15
0x00007fdfec2c1eef <+95>: lea -0x60(%rbp),%r14
0x00007fdfec2c1ef3 <+99>: lea -0x80(%rbp),%r13
0x00007fdfec2c1ef7 <+103>: mov %r12,%rdi
0x00007fdfec2c1efa <+106>: movq $0x0,-0x70(%rbp)
0x00007fdfec2c1f02 <+114>: callq 0x7fdfec73c270 <_ZN4absl14lts_2020_02_255Mutex4LockEv>
0x00007fdfec2c1f07 <+119>: lea -0xfe(%rip),%rax # 0x7fdfec2c1e10 <_ZN4absl14lts_2020_02_259Condition17CastAndCallMethodIN15stream_executor4host10HostStreamEEEbPKS1_>
0x00007fdfec2c1f0e <+126>: mov %r15,%rsi
0x00007fdfec2c1f11 <+129>: mov %r12,%rdi
0x00007fdfec2c1f14 <+132>: movq $0x0,-0xa8(%rbp)
0x00007fdfec2c1f1f <+143>: movq $0x0,-0x98(%rbp)
0x00007fdfec2c1f2a <+154>: mov %rbx,-0x90(%rbp)
0x00007fdfec2c1f31 <+161>: mov %rax,-0xb0(%rbp)
0x00007fdfec2c1f38 <+168>: lea -0x18f(%rip),%rax # 0x7fdfec2c1db0 <_ZN15stream_executor4host10HostStream13WorkAvailableEv>
0x00007fdfec2c1f3f <+175>: mov %rax,-0xa0(%rbp)
0x00007fdfec2c1f46 <+182>: callq 0x7fdfec73bfb0 <_ZN4absl14lts_2020_02_255Mutex5AwaitERKNS0_9ConditionE>
0x00007fdfec2c1f4b <+187>: mov 0x20(%rbx),%rax
0x00007fdfec2c1f4f <+191>: movdqa -0x60(%rbp),%xmm1
0x00007fdfec2c1f54 <+196>: movdqu (%rax),%xmm0
0x00007fdfec2c1f58 <+200>: mov 0x10(%rax),%rcx
0x00007fdfec2c1f5c <+204>: mov 0x18(%rax),%rdx
Dump of assembler code for function start_thread:
0x00007fe006d8f600 <+0>: push %r14
0x00007fe006d8f602 <+2>: push %rbx
0x00007fe006d8f603 <+3>: mov %rdi,%rbx
0x00007fe006d8f606 <+6>: sub $0xa8,%rsp
0x00007fe006d8f60d <+13>: mov %rdi,0x8(%rsp)
0x00007fe006d8f612 <+18>: mov %fs:0x28,%rax
0x00007fe006d8f61b <+27>: mov %rax,0x98(%rsp)
0x00007fe006d8f623 <+35>: xor %eax,%eax
0x00007fe006d8f625 <+37>: rdtsc
0x00007fe006d8f627 <+39>: shl $0x20,%rdx
0x00007fe006d8f62b <+43>: mov %eax,%eax
0x00007fe006d8f62d <+45>: or %rax,%rdx
0x00007fe006d8f630 <+48>: mov %rdx,%fs:0x620
0x00007fe006d8f639 <+57>: mov 0x212968(%rip),%rax # 0x7fe006fa1fa8
0x00007fe006d8f640 <+64>: lea 0x6b8(%rdi),%rdx
0x00007fe006d8f647 <+71>: mov %rdx,%fs:(%rax)
0x00007fe006d8f64b <+75>: callq 0x7fe006d8da10 <__ctype_init@plt>
0x00007fe006d8f650 <+80>: xor %eax,%eax
0x00007fe006d8f652 <+82>: xchg %eax,0x61c(%rbx)
0x00007fe006d8f658 <+88>: cmp $0xfffffffe,%eax
0x00007fe006d8f65b <+91>: je 0x7fe006d8f7cc <start_thread+460>
0x00007fe006d8f661 <+97>: mov 0x8(%rsp),%rbx
0x00007fe006d8f666 <+102>: mov $0x18,%esi
0x00007fe006d8f66b <+107>: mov $0x111,%eax
0x00007fe006d8f670 <+112>: lea 0x2e0(%rbx),%rdi
0x00007fe006d8f677 <+119>: syscall
0x00007fe006d8f679 <+121>: testb $0x4,0x614(%rbx)
0x00007fe006d8f680 <+128>: jne 0x7fe006d8f798 <start_thread+408>
0x00007fe006d8f686 <+134>: lea 0x10(%rsp),%rdi
0x00007fe006d8f68b <+139>: movq $0x0,0x58(%rsp)
0x00007fe006d8f694 <+148>: movq $0x0,0x60(%rsp)
0x00007fe006d8f69d <+157>: callq 0x7fe006d8d8d0 <_setjmp@plt>
0x00007fe006d8f6a2 <+162>: test %eax,%eax
0x00007fe006d8f6a4 <+164>: mov %eax,%ebx
0x00007fe006d8f6a6 <+166>: jne 0x7fe006d8f6e4 <start_thread+228>
0x00007fe006d8f6a8 <+168>: lea 0x10(%rsp),%rax
0x00007fe006d8f6ad <+173>: mov %rax,%fs:0x300
0x00007fe006d8f6b6 <+182>: mov 0x8(%rsp),%rax
0x00007fe006d8f6bb <+187>: cmpb $0x0,0x613(%rax)
0x00007fe006d8f6c2 <+194>: jne 0x7fe006d8f820 <start_thread+544>
0x00007fe006d8f6c8 <+200>: mov 0x8(%rsp),%rax
0x00007fe006d8f6cd <+205>: nop
0x00007fe006d8f6ce <+206>: mov 0x648(%rax),%rdi
0x00007fe006d8f6d5 <+213>: callq *0x640(%rax)
Dump of assembler code for function clone:
0x00007fe0070c8850 <+0>: mov $0xffffffffffffffea,%rax
0x00007fe0070c8857 <+7>: test %rdi,%rdi
0x00007fe0070c885a <+10>: je 0x7fe0070c8899 <clone+73>
0x00007fe0070c885c <+12>: test %rsi,%rsi
0x00007fe0070c885f <+15>: je 0x7fe0070c8899 <clone+73>
0x00007fe0070c8861 <+17>: sub $0x10,%rsi
0x00007fe0070c8865 <+21>: mov %rcx,0x8(%rsi)
0x00007fe0070c8869 <+25>: mov %rdi,(%rsi)
0x00007fe0070c886c <+28>: mov %rdx,%rdi
0x00007fe0070c886f <+31>: mov %r8,%rdx
0x00007fe0070c8872 <+34>: mov %r9,%r8
0x00007fe0070c8875 <+37>: mov 0x8(%rsp),%r10
0x00007fe0070c887a <+42>: mov $0x38,%eax
0x00007fe0070c887f <+47>: syscall
0x00007fe0070c8881 <+49>: test %rax,%rax
0x00007fe0070c8884 <+52>: jl 0x7fe0070c8899 <clone+73>
0x00007fe0070c8886 <+54>: je 0x7fe0070c8889 <clone+57>
0x00007fe0070c8888 <+56>: retq
0x00007fe0070c8889 <+57>: xor %ebp,%ebp
0x00007fe0070c888b <+59>: pop %rax
0x00007fe0070c888c <+60>: pop %rdi
0x00007fe0070c888d <+61>: callq *%rax
0x00007fe0070c888f <+63>: mov %rax,%rdi
0x00007fe0070c8892 <+66>: mov $0x3c,%eax
0x00007fe0070c8897 <+71>: syscall
0x00007fe0070c8899 <+73>: mov 0x2c95c8(%rip),%rcx # 0x7fe007391e68
0x00007fe0070c88a0 <+80>: neg %eax
0x00007fe0070c88a2 <+82>: mov %eax,%fs:(%rcx)
0x00007fe0070c88a5 <+85>: or $0xffffffffffffffff,%rax
0x00007fe0070c88a9 <+89>: retq
Can you also verify this happens with the current released versions of jax and jaxlib, not a self-built jaxlib?
What I really need is a disassembly at the faulting address, which is inside JIT-compiled code in stack frame 0. Apparently dis won't understand JIT-compiled code by default. You might try something like:
disassemble $rip-50, +0x100
(My big issue debugging this is I can't reproduce it.)
(My big issue debugging this is I can't reproduce it.)
I completely understand, thanks so much for still trying to debug it. Do you think there might be any further information I can give to help you reproduce it?
Can you also verify this happens with the current released versions of ...
$ pip3 install --upgrade pip
$ pip3 install --upgrade jax jaxlib
Outcome is the same:
>>> import jax
>>> key = jax.random.PRNGKey(0)
/home/fav25/.local/lib/python3.6/site-packages/jax/lib/xla_bridge.py:116: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
>>> jax.random.normal(key, shape=(3,))
Illegal instruction (core dumped)
What I really need is a disassembly at the faulting address
(gdb) disassemble $rip-50, +0x100
Dump of assembler code from 0x7f6ca4a89779 to 0x7f6ca4a89879:
0x00007f6ca4a89779 <jit__normal.260+1913>: cli
0x00007f6ca4a8977a <jit__normal.260+1914>: and %cl,-0x3bffb28c(%rbp)
0x00007f6ca4a89780 <jit__normal.260+1920>: loop 0x7f6ca4a897ed <jit__normal.260+2029>
0x00007f6ca4a89782 <jit__normal.260+1922>: test $0xcb430f41,%ebp
0x00007f6ca4a89788 <jit__normal.260+1928>: or %eax,%ecx
0x00007f6ca4a8978a <jit__normal.260+1930>: jmpq 0x7f6ca4a891de <jit__normal.260+478>
0x00007f6ca4a8978f <jit__normal.260+1935>: mov 0x8(%r10),%rcx
0x00007f6ca4a89793 <jit__normal.260+1939>: mov (%rcx),%eax
0x00007f6ca4a89795 <jit__normal.260+1941>: shr $0x9,%eax
0x00007f6ca4a89798 <jit__normal.260+1944>: or $0x3f800000,%eax
0x00007f6ca4a8979d <jit__normal.260+1949>: vmovd %eax,%xmm0
0x00007f6ca4a897a1 <jit__normal.260+1953>: movabs $0x7f6ca621c020,%rax
=> 0x00007f6ca4a897ab <jit__normal.260+1963>: vmovss (%rax),%xmm24
0x00007f6ca4a897b1 <jit__normal.260+1969>: movabs $0x7f6ca621c024,%rax
0x00007f6ca4a897bb <jit__normal.260+1979>: vmovss (%rax),%xmm1
0x00007f6ca4a897bf <jit__normal.260+1983>: vfmadd213ss %xmm24,%xmm1,%xmm0
0x00007f6ca4a897c5 <jit__normal.260+1989>: vmovaps %xmm1,%xmm19
0x00007f6ca4a897cb <jit__normal.260+1995>: vmovss %xmm1,-0x18(%rsp)
0x00007f6ca4a897d1 <jit__normal.260+2001>: movabs $0x7f6ca621c028,%rax
0x00007f6ca4a897db <jit__normal.260+2011>: vmovss (%rax),%xmm2
0x00007f6ca4a897df <jit__normal.260+2015>: movabs $0x7f6ca621c02c,%rax
0x00007f6ca4a897e9 <jit__normal.260+2025>: vbroadcastss (%rax),%xmm26
0x00007f6ca4a897ef <jit__normal.260+2031>: vmaxss %xmm0,%xmm2,%xmm22
0x00007f6ca4a897f5 <jit__normal.260+2037>: vmovaps %xmm2,%xmm27
0x00007f6ca4a897fb <jit__normal.260+2043>: vmovss %xmm2,-0x14(%rsp)
0x00007f6ca4a89801 <jit__normal.260+2049>: movabs $0x7f6ca621c030,%rax
0x00007f6ca4a8980b <jit__normal.260+2059>: vbroadcastss (%rax),%xmm1
0x00007f6ca4a89810 <jit__normal.260+2064>: vxorps %xmm1,%xmm22,%xmm0
0x00007f6ca4a89816 <jit__normal.260+2070>: vmovaps %xmm1,%xmm7
0x00007f6ca4a8981a <jit__normal.260+2074>: vmulss %xmm0,%xmm22,%xmm8
0x00007f6ca4a89820 <jit__normal.260+2080>: movabs $0x7f6ca621c034,%rax
0x00007f6ca4a8982a <jit__normal.260+2090>: vmovss (%rax),%xmm14
0x00007f6ca4a8982e <jit__normal.260+2094>: vaddss %xmm14,%xmm8,%xmm0
0x00007f6ca4a89833 <jit__normal.260+2099>: vxorps %xmm2,%xmm2,%xmm2
0x00007f6ca4a89837 <jit__normal.260+2103>: xor %esi,%esi
0x00007f6ca4a89839 <jit__normal.260+2105>: vucomiss %xmm2,%xmm0
0x00007f6ca4a8983d <jit__normal.260+2109>: setbe %sil
0x00007f6ca4a89841 <jit__normal.260+2113>: neg %esi
0x00007f6ca4a89843 <jit__normal.260+2115>: vcmpeqss %xmm2,%xmm0,%k0
0x00007f6ca4a8984a <jit__normal.260+2122>: kmovw %k0,%edi
0x00007f6ca4a8984e <jit__normal.260+2126>: neg %edi
0x00007f6ca4a89850 <jit__normal.260+2128>: movabs $0x7f6ca621c038,%rax
0x00007f6ca4a8985a <jit__normal.260+2138>: vmovss (%rax),%xmm1
0x00007f6ca4a8985e <jit__normal.260+2142>: xor %ebp,%ebp
Ok, that's interesting. I note that vmovss (%rax),%xmm24 is an AVX512 instruction; AVX512 instructions are relatively new (Intel Skylake or newer). It looks perfectly legal to me.
I have two hypotheses:
a) the access is unaligned. We could determine if that's the case if you could also dump the register values via info registers. I'm particularly interested in the value of rax at the faulting instruction. If it's not a multiple of 16, that's the problem.
b) you have a virtualization issue. I note that lscpu shows that you are running under a Xen hypervisor. Is it possible that the version of Xen you are using does not support AVX512? Have you run other AVX512 code successfully?
Did lscpu show any CPU feature flags? Normally there's a flags: line at the bottom; if it's not there, you might also look at the contents of /proc/cpuinfo.
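For reference, a minimal standalone sketch (illustrative only, not part of jax or XLA) of checking that flags line programmatically instead of eyeballing the lscpu output:

```cpp
#include <fstream>
#include <iostream>
#include <string>

// Reads the first "flags" line from /proc/cpuinfo and reports whether the
// kernel advertises AVX2 and AVX512F to user space.
int main() {
  std::ifstream cpuinfo("/proc/cpuinfo");
  std::string line;
  while (std::getline(cpuinfo, line)) {
    if (line.compare(0, 5, "flags") == 0) {
      line += ' ';  // so every flag, including the last one, is space-delimited
      bool avx2 = line.find(" avx2 ") != std::string::npos;
      bool avx512f = line.find(" avx512f ") != std::string::npos;
      std::cout << "avx2: " << avx2 << ", avx512f: " << avx512f << "\n";
      return 0;  // one logical CPU is enough for this check
    }
  }
  std::cout << "no flags line found\n";
  return 1;
}
```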
One way we could confirm this is an AVX512 problem would be for you to locally rebuild jaxlib with AVX512 support disabled. To do this, do the following:
1. Check out TensorFlow inside your jax git checkout directory (git clone https://github.com/tensorflow/tensorflow.git).
2. Edit the WORKSPACE file so the org_tensorflow rule points at that local checkout instead of the downloaded archive.
3. Rebuild jaxlib.
Thanks for bearing with us; we need you to do some of the heavy lifting because we can't reproduce this!
Feature flags (no hits for avx5.*, there's an avx2 flag):
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush acpi mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti intel_ppin ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx rdseed adx smap clflushopt clwb xsaveopt xsavec xgetbv1 xsaves pku ospke md_clear flush_l1d
One way we could confirm this is an AVX512 problem would be for you to locally rebuild jaxlib with AVX512 support disabled.
Attempting this now.
This is after triggering the event. Should I have run a command before it, like moving to a specific stack frame?
info registers
rax 0x7f20a2d15020 139778147307552
rbx 0x2ff9780 50304896
rcx 0x3016550 50423120
rdx 0x6 6
rsi 0x3016530 50423088
rdi 0x3016550 50423120
rbp 0x3016500 0x3016500
rsp 0x7f20707f78a0 0x7f20707f78a0
r8 0x1bd11bda 466688986
r9 0x13 19
r10 0x30164c0 50422976
r11 0x0 0
r12 0xd 13
r13 0x0 0
r14 0x11 17
r15 0xf 15
rip 0x7f20a2d167ab 0x7f20a2d167ab <jit.normal+1963>
eflags 0x10202 [ IF RF ]
cs 0x33 51
ss 0x2b 43
ds 0x0 0
es 0x0 0
fs 0x0 0
gs 0x0 0
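(As an aside, hypothesis (a) can be checked directly against this dump: the rax value gdb printed, 139778147307552, i.e. 0x7f20a2d15020, is a multiple of 16, so the access looks aligned. A trivial standalone sketch of that check:)

```cpp
#include <cstdint>
#include <cstdio>

// Checks whether the rax value reported by gdb above is 16-byte aligned
// (hypothesis (a)); 139778147307552 is the decimal value from the dump.
int main() {
  const std::uint64_t rax = 139778147307552ULL;
  std::printf("rax = %#llx, rax %% 16 = %llu\n",
              (unsigned long long)rax, (unsigned long long)(rax % 16));
  return 0;
}
```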
Getting a crash; this is what I changed WORKSPACE to:
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_closure",
sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9",
strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149",
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13
],
)
# https://github.com/bazelbuild/bazel-skylib/releases
http_archive(
name = "bazel_skylib",
sha256 = "1dde365491125a3db70731e25658dfdd3bc5dbdfd11b840b3e987ecf043c7ca0",
urls = [
"http://mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz",
"https://github.com/bazelbuild/bazel-skylib/releases/download/0.9.0/bazel_skylib-0.9.0.tar.gz",
],
)
# To update TensorFlow to a new revision,
# a) update URL and strip_prefix to the new git commit hash
# b) get the sha256 hash of the commit by running:
# curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.
#http_archive(
# name = "org_tensorflow",
# sha256 = "86ec522b57d5a7f30e604153b8a5e0337e8449b3caa34a186200b83fffaf9295",
# strip_prefix = "tensorflow-98a5b3b6d13fcf0b8a43b77dbfc108b8953b23b9",
# urls = [
# "https://github.com/tensorflow/tensorflow/archive/98a5b3b6d13fcf0b8a43b77dbfc108b8953b23b9.tar.gz",
# ],
#)
# For development, one can use a local TF repository instead.
local_repository(
name = "org_tensorflow",
path = "tensorflow",
)
load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace", "tf_bind")
tf_workspace(
path_prefix = "",
tf_repo_name = "org_tensorflow",
)
tf_bind()
# Required for TensorFlow dependency on @com_github_grpc_grpc
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load(
"@build_bazel_rules_apple//apple:repositories.bzl",
"apple_rules_dependencies",
)
apple_rules_dependencies()
load(
"@build_bazel_apple_support//lib:repositories.bzl",
"apple_support_dependencies",
)
apple_support_dependencies()
load("@upb//bazel:repository_defs.bzl", "bazel_version_repository")
bazel_version_repository(name = "bazel_version")
This orphan process is left running after the build fails:
00:00:19 bazel(jax) -XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/auto/homes/fav25/.cache/bazel/_bazel_fav25/1c964dca75cdce3cf147a83388d88299
Yes, (a) is fine.
For (b), that looks like the correct WORKSPACE change. When you say "crash", what happens? If the build fails, what error do you get?
Our strong suspicion is that this is a bug in LLVM's CPU feature detection code. Your CPU supports AVX512, but your OS and/or hypervisor disable it, and LLVM is most likely getting confused. Reverting this change: https://github.com/tensorflow/tensorflow/commit/110e869def6ea76b7e159eb781edabafebe3a219#diff-39edefb280021958bce4a9ae4f3ad425 should confirm this, I believe.
Oh, I have a guess what went wrong. Try making sure your tensorflow checkout has the same release as the version already in the workspace, e.g. git checkout 98a5b3b6d13fcf0b8a43b77dbfc108b8953b23b9 in the jax/tensorflow directory.
Apologies, I had not realised I did not attach the error:
ERROR: /auto/homes/fav25/.cache/bazel/_bazel_fav25/1c964dca75cdce3cf147a83388d88299/external/org_tensorflow/tensorflow/compiler/xla/client/BUILD:108:1: @org_tensorflow//tensorflow/compiler/xla/client:local_client depends on @llvm-project//llvm:support in repository @llvm-project which failed to fetch. no such package '@llvm-project//llvm': unlinkat(/auto/homes/fav25/.cache/bazel/_bazel_fav25/1c964dca75cdce3cf147a83388d88299/external/llvm-project/ac2aaa3788cc5e7e2bd3752ad9f71e37f411bdca.tar.gz) (Permission denied)
ERROR: Analysis of target '//build:install_xla_in_source_tree' failed; build aborted: no such package '@llvm-project//llvm': unlinkat(/auto/homes/fav25/.cache/bazel/_bazel_fav25/1c964dca75cdce3cf147a83388d88299/external/llvm-project/ac2aaa3788cc5e7e2bd3752ad9f71e37f411bdca.tar.gz) (Permission denied)
then:
Traceback (most recent call last):
File "build/build.py", line 380, in <module>
main()
File "build/build.py", line 375, in main
shell(command)
File "build/build.py", line 47, in shell
output = subprocess.check_output(cmd)
File "/usr/lib/python3.6/subprocess.py", line 356, in check_output
**kwargs).stdout
File "/usr/lib/python3.6/subprocess.py", line 438, in run
output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['./bazel-2.0.0-linux-x86_64', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', ':install_xla_in_source_tree', '/auto/homes/fav25/jax/build']' returned non-zero exit status 1.
Still getting the same error after git checkout 98a5b3b6d13fcf0b8a43b77dbfc108b8953b23b9
I'm guessing from the path you have an NFS home directory? I don't think Bazel supports that. You might try deleting your bazel cache (~/.cache/bazel).
That fixed it. Annoyingly, it could not compile the modified tensorflow module:
ERROR: /auto/homes/fav25/.cache/bazel/_bazel_fav25/1c964dca75cdce3cf147a83388d88299/external/org_tensorflow/tensorflow/compiler/xla/service/cpu/BUILD:194:1: C++ compilation of rule '@org_tensorflow//tensorflow/compiler/xla/service/cpu:simple_orc_jit' failed (Exit 1)
In file included from external/org_tensorflow/tensorflow/compiler/xla/service/buffer_assignment.h:37:0,
from external/org_tensorflow/tensorflow/compiler/xla/service/compiler.h:30,
from external/org_tensorflow/tensorflow/compiler/xla/service/llvm_compiler.h:20,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/compiler_functor.h:24,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h:31,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:16:
external/org_tensorflow/tensorflow/compiler/xla/service/memory_space_assignment.h:364:3: warning: multi-line comment [-Wcomment]
// / \
^
external/org_tensorflow/tensorflow/compiler/xla/service/memory_space_assignment.h:801:3: warning: multi-line comment [-Wcomment]
// / \ \ \
^
external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc: In function 'llvm::SmallVector<std::__cxx11::basic_string<char>, 0> xla::cpu::{anonymous}::DetectMachineAttributes()':
external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:68:38: error: no matching function for call to 'llvm::SmallVector<std::__cxx11::basic_string<char>, 0>::push_back(llvm::StringRef&)'
result.push_back(feature_name);
^
In file included from external/llvm-project/llvm/include/llvm/ADT/Twine.h:12:0,
from external/llvm-project/llvm/include/llvm/ADT/Triple.h:12,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h:23,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:16:
external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:238:8: note: candidate: void llvm::SmallVectorTemplateBase<T, <anonymous> >::push_back(const T&) [with T = std::__cxx11::basic_string<char>; bool <anonymous> = false]
void push_back(const T &Elt) {
^~~~~~~~~
external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:238:8: note: no known conversion for argument 1 from 'llvm::StringRef' to 'const std::__cxx11::basic_string<char>&'
external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:245:8: note: candidate: void llvm::SmallVectorTemplateBase<T, <anonymous> >::push_back(T&&) [with T = std::__cxx11::basic_string<char>; bool <anonymous> = false]
void push_back(T &&Elt) {
^~~~~~~~~
external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:245:8: note: no known conversion for argument 1 from 'llvm::StringRef' to 'std::__cxx11::basic_string<char>&&'
external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc: In constructor 'xla::cpu::SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions&, llvm::CodeGenOpt::Level, bool, bool, llvm::FastMathFlags, xla::LLVMCompiler::ModuleHook, xla::LLVMCompiler::ModuleHook, std::function<void(const llvm::object::ObjectFile&)>)':
external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:145:66: warning: 'llvm::orc::LegacyRTDyldObjectLinkingLayer::LegacyRTDyldObjectLinkingLayer(llvm::orc::ExecutionSession&, llvm::orc::LegacyRTDyldObjectLinkingLayer::ResourcesGetter, llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor, llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyFinalizedFtor, llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyFreedFtor)' is deprecated [-Wdeprecated-declarations]
llvm::JITEventListener::createGDBRegistrationListener()) {
^
In file included from external/llvm-project/llvm/include/llvm/Support/AlignOf.h:16:0,
from external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:17,
from external/llvm-project/llvm/include/llvm/ADT/Twine.h:12,
from external/llvm-project/llvm/include/llvm/ADT/Triple.h:12,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h:23,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:16:
external/llvm-project/llvm/include/llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h:378:3: note: declared here
LLVM_ATTRIBUTE_DEPRECATED(
^
external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:145:66: warning: 'llvm::orc::LegacyIRCompileLayer<BaseLayerT, CompileFtor>::LegacyIRCompileLayer(BaseLayerT&, CompileFtor, llvm::orc::LegacyIRCompileLayer<BaseLayerT, CompileFtor>::NotifyCompiledCallback) [with BaseLayerT = llvm::orc::LegacyRTDyldObjectLinkingLayer; CompileFtor = std::function<llvm::Expected<std::unique_ptr<llvm::MemoryBuffer> >(llvm::Module&)>; llvm::orc::LegacyIRCompileLayer<BaseLayerT, CompileFtor>::NotifyCompiledCallback = std::function<void(long unsigned int, std::unique_ptr<llvm::Module>)>]' is deprecated [-Wdeprecated-declarations]
llvm::JITEventListener::createGDBRegistrationListener()) {
^
In file included from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h:26:0,
from external/org_tensorflow/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc:16:
external/llvm-project/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h:155:1: note: declared here
LegacyIRCompileLayer<BaseLayerT, CompileFtor>::LegacyIRCompileLayer(
^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Here is the modified file simple_orc_jit.cc: https://gist.github.com/franciscovargas/dfa42f5fa61604e57692f483a5426676
I reverted to the diff you shared and then fixed the merge conflicts (commented them out). It seems to be failing due to an incompatibility with the current type signature of push_back. Should I cast feature_name?
std::string(feature_name) should fix it. I'll try this and rerun.
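For reference, a sketch of roughly what the discussed change to DetectMachineAttributes() in simple_orc_jit.cc looks like. The surrounding loop is an assumption reconstructed from the compile errors above; the only point being made is the explicit StringRef-to-std::string conversion before push_back:

```cpp
#include <string>
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Host.h"

// Hypothetical shape of DetectMachineAttributes() after the cast discussed
// above. The loop body is assumed; only the push_back line is the actual fix.
llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
  llvm::SmallVector<std::string, 0> result;
  llvm::StringMap<bool> host_features;
  if (llvm::sys::getHostCPUFeatures(host_features)) {
    for (auto& feature : host_features) {
      if (feature.second) {
        llvm::StringRef feature_name = feature.first();
        // StringRef does not convert implicitly to std::string at this LLVM
        // revision, so construct the string explicitly before pushing.
        result.push_back(std::string(feature_name));
      }
    }
  }
  return result;
}
```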
Finally got the change built. You were correct about the virtualisation issue:
>>> import jax
>>> key = jax.random.PRNGKey(0)
/auto/homes/fav25/jax/jax/lib/xla_bridge.py:116: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
>>> jax.random.normal(key, shape=(3,))
DeviceArray([ 1.8160859 , -0.48262325, 0.33988902], dtype=float32)
This means the "bug" is on my side and I should try to sort this out with the sysadmins to update Xen to support AVX512?
Thanks for the help! Sorry that this issue was on my side and I used up your time debugging it.
It's a bit of both, actually. Your configuration is unusual, in that your CPU supports AVX512 but your hypervisor/OS don't enable it. But LLVM should handle that gracefully, especially given the CPU feature flags clearly say that AVX512 features aren't available. I'll follow up with our compiler team (if nothing else to make it easier to disable...).
But you might also do well asking if your installation can enable AVX512, since it might well be helpful for performance.
Thanks for your assistance tracking this down, by the way. We wouldn't have been able to debug this without your help.
I've just been pointed at this, seeing as your sysadmin knows me, and I'm an x86 maintainer in Xen.
@hawkinsp As to your two hypotheses, you can actually spot the difference. Illegal Instruction / SIGILL means the instruction isn't available (and in this case, not turned on by the OS), whereas a misaligned memory operand manifests as a General Protection Fault / SIGSEGV.
Whatever is going on here, the feature detection isn't following the rules. Feature dispatching logic is required to look first at the OSXSAVE bit in CPUID, then execute XGETBV to see what the OS has turned on for them (in this case, the missing feature is ZMM: while vmovss (%rax),%xmm24 is operating on a 128-bit register, it uses an EVEX encoding to access the %xmm24 register, which is outside the encoding space of AVX2), and then check for the CPU instruction flags applicable to the enabled state.
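For readers following along, a minimal sketch of the detection order described above, using GCC-style x86 intrinsics; this is illustrative only, not the code path XLA or LLVM actually use:

```cpp
#include <cpuid.h>
#include <cstdint>

// Step 1 and 2: check CPUID.1:ECX.OSXSAVE, then read XCR0 via XGETBV to see
// which register state the OS has actually enabled.
bool OsEnablesZmmState() {
  unsigned eax, ebx, ecx, edx;
  if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) return false;
  if (!(ecx & (1u << 27))) return false;  // OSXSAVE not set: XGETBV unusable
  std::uint32_t lo, hi;
  asm volatile("xgetbv" : "=a"(lo), "=d"(hi) : "c"(0));  // read XCR0
  (void)hi;
  // XMM (bit 1), YMM (bit 2) and opmask/ZMM_Hi256/Hi16_ZMM (bits 5-7).
  return (lo & 0xE6) == 0xE6;
}

// Step 3: only then consult the CPU's instruction-set bit.
bool CpuAndOsSupportAvx512F() {
  if (!OsEnablesZmmState()) return false;  // false under this Xen setup
  unsigned eax, ebx, ecx, edx;
  if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) return false;
  return (ebx & (1u << 16)) != 0;  // CPUID.(EAX=7,ECX=0):EBX bit 16 = AVX512F
}
```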
In Xen(Server), we deliberately disable AVX512 by default. This is hiding all relevant CPUID bits, and preventing the kernel from turning it on. Whatever logic exists here is probably assuming the availability of AVX512 based on the CPU family/model, which isn't necessarily true even on native.
The reason we turn it off is for performance - AVX512 causes substantial frequency ramping across the entire socket, and it is very easy to make your net performance (and those of the other VMs) worse than not having AVX512 enabled in the first place. (There are some cases where even a single bespoke fully optimised application using the entire system might still be faster with AVX2 than AVX512).
You (well - your sysadmin) can turn AVX512 on, but I wouldn't bet on the end result being faster.
LLVM seems to be doing the right thing: https://github.com/llvm/llvm-project/blob/master/llvm/lib/Support/Host.cpp#L1407
We call that from here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc#L80. I wonder if the CPU name overrides the features somehow.
@andyhhp Thanks so much for your help! I agree, AVX512 might not be a win depending on the workloads, and for a shared machine with a mixture of workloads it probably isn't. Your pointers on feature detection were great, because it allowed us to quickly verify that LLVM's logic was doing the right thing.
@d0k made a change to XLA that we think will fix this issue, and I merged it into JAX. Thanks for the quick fix @d0k !
@franciscovargas could you try building a new jaxlib from an unmodified jax GitHub head and verify it fixes your problem? We'll probably make a new binary release anyway in the next week, but we'll need your help to verify that the issue is fixed.
@hawkinsp it works fine with the new release:
>>> import jax
>>> key = jax.random.PRNGKey(0)
/auto/homes/fav25/jax/jax/lib/xla_bridge.py:116: UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
>>> jax.random.normal(key, shape=(3,))
DeviceArray([ 1.8160859 , -0.48262325, 0.33988902], dtype=float32)
I cloned jax from GitHub and built it from source without any modifications; looks like the XLA change works! Thanks for the prompt fix!
@franciscovargas That's great! Thanks for helping us track it down.
Script:
Output:
I have followed the installation instructions for CPU using pip and also the installation instructions for building from source (cloned the latest version: git clone https://github.com/google/jax).
Linux version:
I could not find the file _pywrap_xla.so. I searched inside the build directory and in the jax folder (since the package was installed from there, it won't be in site-packages). The most similar binary I found was:
The outcomes are the same when installed using pip.
lscpu output: