lionelmessi6410 / ntga

Code for "Neural Tangent Generalization Attacks" (ICML 2021)
Apache License 2.0

'ShapedArray' object has no attribute 'val' #2

Open liuyixin-louis opened 2 years ago

liuyixin-louis commented 2 years ago

Hi, nice work, and thanks for sharing the code. When I ran the code, I encountered the following error:

jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
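For context, the detailed traceback ends in `neural_tangents/stax.py` at `int(x.val)`, called on the result of shape propagation. In JAX, abstract evaluation carries only shape and dtype metadata, not concrete values. The minimal sketch below uses the public `jax.eval_shape` API on a toy `init_like` function (hypothetical, not NTGA or neural_tangents code) to show that such results expose `.shape` and `.dtype` but no `.val`:

```python
import jax
import jax.numpy as jnp


def init_like(x):
    # Hypothetical stand-in for a layer's init function: we only care about
    # the shape of what it would return, not its values.
    return jnp.zeros((x.shape[0], 10))


# eval_shape traces the function abstractly: nothing is computed, and the
# result holds only shape/dtype metadata.
out = jax.eval_shape(init_like, jax.ShapeDtypeStruct((8, 3), jnp.float32))

print(out.shape, out.dtype)  # (8, 10) float32
print(hasattr(out, "val"))   # False: there is no concrete value to read
```

Code that needs a Python `int` from an abstractly evaluated shape should read it from `.shape` rather than from a concrete value that abstract tracing may not provide.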

The detailed output is below:

Loading dataset...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Building model...
Generating NTGA....
  0%|          | 0/78 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 593, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 668, in _xla_callable
    fun, abstract_args, pe.debug_info_final(fun, "jit"))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 829, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 901, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 1997, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 318, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 195, in process_call
    f, in_pvals, app, instantiate=False)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 303, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1072, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
AttributeError: 'ShapedArray' object has no attribute 'val'
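One possible reading of the final frame (an assumption, not confirmed anywhere in this thread) is that the installed `neural_tangents` expects `.val` on the output of abstract shape evaluation, while the installed JAX returns a plain `ShapedArray` there. That would point to a version mismatch between `jax`, `jaxlib` and `neural-tangents`. When following up on a report like this, a small script can list the installed versions. The `report_versions` helper below is hypothetical and not part of NTGA; it uses `importlib.metadata` on Python 3.8+ and falls back to `pkg_resources` on older interpreters such as the Python 3.6 environment shown in the traceback:

```python
"""Report installed versions of the packages that appear in the traceback."""
from typing import Dict, Iterable, Optional

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.6/3.7: fall back to setuptools' pkg_resources
    import pkg_resources

    class PackageNotFoundError(Exception):
        """Raised when a distribution is not installed."""

    def version(name: str) -> str:
        try:
            return pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound as exc:
            raise PackageNotFoundError(name) from exc


def report_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found = {}  # type: Dict[str, Optional[str]]
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    import platform

    print("python", platform.python_version())
    for pkg, ver in report_versions(["jax", "jaxlib", "neural-tangents"]).items():
        print(pkg, ver or "not installed")
```

Pasting this output next to the traceback makes it easier to compare the environment against the versions the repository was tested with.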
xrose3159 commented 1 year ago
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
AttributeError: 'ShapedArray' object has no attribute 'val'

I'm having the same issue! Have you solved this problem? If so, how?
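The failing line in `neural_tangents/stax.py` is `int(x.val)`. It expects a value it can read, but the installed JAX hands it an abstract `ShapedArray`, which carries only shape and dtype information. One plausible cause, not confirmed in this thread, is a mismatch between the installed JAX and Neural Tangents versions. Comparing version numbers is easier if everyone hitting the error reports them. The sketch below collects them. The helper name `installed_versions` is hypothetical. Some packages may not define `__version__`, in which case it reports `"unknown"`. It also runs on the Python 3.6 environment shown in the traceback.

```python
import importlib


def installed_versions(names=("jax", "jaxlib", "neural_tangents")):
    """Hypothetical helper: map each package name to its reported version.

    Returns None for packages that are not installed, and the error text
    for packages that are installed but fail to import (for example, an
    incompatible jax/jaxlib pair).
    """
    versions = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            versions[name] = None  # package not installed
            continue
        except Exception as exc:  # installed but broken at import time
            versions[name] = "import failed: {}".format(exc)
            continue
        # Not every package exposes __version__; fall back to "unknown".
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version}")
```

Running `pip show jax jaxlib neural-tangents` gives the same information from the package metadata, if you prefer not to import the packages.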