google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.28k stars, 226 forks

AttributeError: module 'neural_tangents' has no attribute 'utils' #173

Open LeavesLei opened 1 year ago

LeavesLei commented 1 year ago

Hi developers, I ran into the following problem when using neural-tangents:

KERNEL_FN = nt.utils.batch.batch(KERNEL_FN, batch_size=kernel_batch_size)
AttributeError: module 'neural_tangents' has no attribute 'utils'

Here are the versions of some of the libraries:

romanngg commented 1 year ago

We did a refactoring a while ago; please try nt.batch instead.

See https://github.com/google/neural-tangents/blob/main/neural_tangents/__init__.py for the public API

LeavesLei commented 1 year ago

Thanks for your fast reply. I changed nt.utils.batch.batch() to nt.batch(), but another error occurred:

Traceback (most recent call last):                                                                                                           
  File "eval_distilled_set.py", line 190, in <module>                                                                                         
    main()                                                                                                                                    
  File "eval_distilled_set.py", line 156, in main                                                                                             
    K_zz = KERNEL_FN(X_sup_reordered, X_sup_reordered)                                                                                        
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/utils/utils.py", line 188, in h                                           
    return g(*args, **kwargs)                                                                                                                
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 471, in serial_fn                                      
    return serial_fn_x1(x1_or_kernel, x2, *args, **kwargs)                                                                                    
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 398, in serial_fn_x1                                   
    _, kernel = _scan(row_fn, 0, (x1s, kwargs_np1))                                                                                           
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 151, in _scan                                          
    carry, y = f(carry, x)                                                                                                                    
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 387, in row_fn                                         
    return _, _scan(col_fn, x1, (x2s, kwargs_np2))[1]                                                                                         
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 151, in _scan                                          
    carry, y = f(carry, x)                                                                                                                    
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 396, in col_fn                                         
    return (x1, kwargs1), kernel_fn(x1, x2, *args, **kwargs_merge)                                                                            
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/utils/utils.py", line 188, in h                                           
    return g(*args, **kwargs)                                                                                                                 
  File "/usr/local/lib/python3.8/dist-packages/neural_tangents/_src/batching.py", line 758, in f_pmapped                                      
    return _f(x_or_kernel, *args_np, **kwargs_np)                                                                                             
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:627) dnn != nullptr

where KERNEL_FN = functools.partial(kernel_fn, get=('nngp', 'ntk')).

romanngg commented 1 year ago

I haven't seen this error before. Does it still happen if you reduce the batch size? I sometimes encounter low-level XLA errors when running out of memory.
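When trying smaller batch sizes, it can help to pick them systematically. The sketch below assumes that the batched kernel function expects the number of inputs to be divisible by the batch size (worth checking against the installed version); `largest_valid_batch_size` is a hypothetical helper, not a library function.

```python
def largest_valid_batch_size(n_samples, max_batch_size):
    """Return the largest divisor of n_samples that is <= max_batch_size.

    Hypothetical helper; assumes the batch size must divide the input count.
    """
    if n_samples < 1 or max_batch_size < 1:
        raise ValueError("both arguments must be positive")
    # Walk downwards from the cap; 1 always divides, so the loop always returns.
    for candidate in range(min(n_samples, max_batch_size), 0, -1):
        if n_samples % candidate == 0:
            return candidate
```

For example, halving the cap repeatedly (25, 12, 6, 3, ...) and passing each value through this helper gives a sequence of batch sizes to try until the kernel computation fits in memory.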

LeavesLei commented 1 year ago

I've reduced the batch size from 25 to 5, but the error still occurs. Given the dnn != nullptr check, could a mismatch between the cuDNN version and jax be causing the problem? (https://github.com/google/jax/issues/14480)

I am using Ubuntu 20.04, CUDA 11.4, and cuDNN 8.7.0; the GPU is a TITAN V (12GB).

romanngg commented 1 year ago

Good catch, that could be it. What are your jax and jaxlib [edit: and nvidia driver] versions?

LeavesLei commented 1 year ago

import jax, jaxlib
jax.__version__: 0.4.4
jaxlib.__version__: 0.4.4

NVIDIA-SMI 470.161.03, Driver Version: 470.161.03
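A report like the one above can be gathered in one go. The following is a minimal sketch using only the standard library; the function names are hypothetical, and it assumes `nvidia-smi` is on the PATH when a driver is installed.

```python
import shutil
import subprocess
from importlib import metadata


def installed_version(package):
    """Return the installed version of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def nvidia_driver_version():
    """Return the driver version reported by nvidia-smi, or None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True, text=True, check=True, timeout=30,
        )
    except (subprocess.SubprocessError, OSError):
        return None
    lines = result.stdout.strip().splitlines()
    # One line is printed per GPU; the driver version is the same for all.
    return lines[0].strip() if lines else None


def environment_report(packages=("jax", "jaxlib", "neural-tangents")):
    """Collect package versions and the driver version into one dict."""
    report = {name: installed_version(name) for name in packages}
    report["nvidia-driver"] = nvidia_driver_version()
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name}: {version}")
```

The CUDA and cuDNN versions are not covered here, since how they are installed (system packages or pip wheels) varies between setups.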

romanngg commented 1 year ago

Hm, these all seem compatible per the cuDNN support matrix (https://docs.nvidia.com/deeplearning/cudnn/support-matrix/index.html). Have you tried updating as suggested in this comment? (https://github.com/google/jax/issues/14480#issuecomment-1431697859)

LeavesLei commented 1 year ago

Hi, Roman

Thanks for your reply, and I'll try to update the cuDNN version to solve the problem.

Best, Shiye