Open liuyixin-louis opened 2 years ago
Hi, nice work, and thanks for sharing the code. When I ran the code, I encountered the following error:
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
The detailed output is below
Loading dataset...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Building model...
Generating NTGA....
  0%| | 0/78 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 593, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 668, in _xla_callable
    fun, abstract_args, pe.debug_info_final(fun, "jit"))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 829, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 901, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 1997, in _vjp
    flat_fun, primals_flat, reduce_axes=reduce_axes)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 318, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 195, in process_call
    f, in_pvals, app, instantiate=False)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 303, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1072, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "generate_attack.py", line 228, in <module>
    main()
  File "generate_attack.py", line 195, in main
    nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
  File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
    fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
  File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
    targeted)
  File "generate_attack.py", line 146, in adv_loss
    ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
    out_kernel = kernel_fn(kernel, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
    return kernel_fn(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
    k = f(k, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
    return g(*args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
    fn_out = fn(*canonicalized_args, **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
    **kwargs)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
    return _set_shapes(init_fn, kernel, out_kernel)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
    shape1 = _propagate_shape(init_fn, in_kernel.shape1)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
    out_shape = tree_map(lambda x: int(x.val), out_shape)
  File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
    out_shape = tree_map(lambda x: int(x.val), out_shape)
AttributeError: 'ShapedArray' object has no attribute 'val'
I'm having the same issue too! Have you solved this problem? How was it solved?
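No fix is recorded in this thread. The last frames of the traceback are inside `neural_tangents/stax.py`, which reads `.val` from a value that JAX reports as a `ShapedArray`. One plausible cause (an assumption, not confirmed here) is that the installed `jax` and `neural-tangents` releases do not match each other. As a first diagnostic, a minimal sketch like the following prints the installed versions so they can be compared against whatever the repository's requirements specify. The helper name `installed_versions` is made up for illustration and is not part of either library.

```python
import sys

try:
    # Available on Python 3.8 and later.
    from importlib.metadata import version, PackageNotFoundError
except ImportError:
    # Fallback for older interpreters, such as the Python 3.6 env in the traceback.
    import pkg_resources

    PackageNotFoundError = pkg_resources.DistributionNotFound

    def version(name):
        return pkg_resources.get_distribution(name).version


def installed_versions(packages=("jax", "jaxlib", "neural-tangents")):
    """Map each package name to its installed version, or None if it is missing."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    print("python", sys.version.split()[0])
    for name, ver in installed_versions().items():
        print(name, ver if ver is not None else "not installed")
```

Posting this output together with the traceback would make it easier for others to tell whether the environment differs from the one the code was written for.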