LeavesLei opened this issue 1 year ago
We had a refactoring a while ago; please try nt.batch.
See https://github.com/google/neural-tangents/blob/main/neural_tangents/__init__.py for the public API
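For reference, a minimal sketch of how the batched public API is typically called (the two-layer stax network and the batch size here are placeholders, not the model from eval_distilled_set.py):

```python
import neural_tangents as nt
from neural_tangents import stax

# Placeholder infinite-width model; substitute your own architecture.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# nt.batch wraps kernel_fn so the kernel is computed in tiles,
# which bounds peak device memory for large inputs.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=25)

# Same call signature as the unbatched kernel_fn:
# k = batched_kernel_fn(x1, x2, ('nngp', 'ntk'))
```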
Thanks for your fast reply. I changed nt.utils.batch.batch() to nt.batch(), but another error occurred, as follows:
Traceback (most recent call last):
File "eval_distilled_set.py", line 190, in <module>
main()
File "eval_distilled_set.py", line 156, in main
K_zz = KERNEL_FN(X_sup_reordered, X_sup_reordered)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 471, in serial_fn
return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1
_, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 387, in row_fn
return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 151, in _scan
carry, y = f(carry, x)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 396, in col_fn
return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/utils/utils.py", line 188, in h
return g(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped
return _f(x_or_kernel, *args_np, **kwargs_np)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:627) dnn != nullptr
where KERNEL_FN = functools.partial(kernel_fn, get=('nngp', 'ntk')).
I haven't seen this error before. Does it still happen if you reduce the batch size? I sometimes encounter low-level XLA errors when running out of memory.
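In case it is a memory issue, here is a sketch of the two knobs that usually lower peak GPU memory (batch_size and store_on_device are nt.batch arguments; the stax model is a placeholder standing in for the kernel_fn already built in eval_distilled_set.py):

```python
import functools
import neural_tangents as nt
from neural_tangents import stax

# Placeholder for the kernel_fn constructed in eval_distilled_set.py.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# Smaller tiles plus host-side storage of the assembled kernel both lower
# peak GPU memory; store_on_device=False keeps the result in host RAM.
KERNEL_FN = functools.partial(
    nt.batch(kernel_fn, batch_size=5, store_on_device=False),
    get=('nngp', 'ntk'),
)

# K_zz = KERNEL_FN(X_sup_reordered, X_sup_reordered)
```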
I've reduced the batch size from 25 to 5, but the error still occurred. I guess a mismatch between the cuDNN version and jax is causing the problem, given the dnn != nullptr check? (https://github.com/google/jax/issues/14480)
I am using Ubuntu 20.04, CUDA 11.4, cuDNN 8.7.0, and the GPU is a TITAN V (12 GB).
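One way to check whether cuDNN itself is the problem, independent of neural-tangents: a tiny convolution makes XLA:GPU go through cuDNN, so the same dnn != nullptr failure may reproduce here if the cuDNN install is what is broken. This is my own diagnostic sketch, not something from the linked issue:

```python
import jax
import jax.numpy as jnp

print(jax.devices())  # should list the TITAN V if the GPU backend loaded

# A minimal convolution exercises the cuDNN path in XLA:GPU; if cuDNN is
# mis-installed or incompatible with jaxlib, this may fail with the same
# "dnn != nullptr" RET_CHECK, with neural-tangents out of the picture.
x = jnp.ones((1, 3, 8, 8))   # NCHW input
w = jnp.ones((4, 3, 3, 3))   # OIHW filters
y = jax.lax.conv(x, w, window_strides=(1, 1), padding='SAME')
print(y.shape)  # expected: (1, 4, 8, 8)
```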
Good catch, could be. What are your jax and jaxlib [edit: and NVIDIA driver] versions?
import jax, jaxlib
jax.__version__: 0.4.4
jaxlib.__version__: 0.4.4
NVIDIA-SMI 470.161.03, Driver Version: 470.161.03
Hm, these all seem compatible per https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html. Have you tried updating per https://github.com/google/jax/issues/14480#issuecomment-1431697859?
Hi, Roman
Thanks for your reply, and I'll try to update the cuDNN version to solve the problem.
Best, Shiye
Hi developers, I've run into a problem when using neural-tangents, as follows:
Here are the versions of some of the libraries: