When I run the training script, I encounter the following problem.
By the way, training works fine when I run the training script from Nerfies with a similar config.
train.py:302] Starting training
Traceback (most recent call last):
File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 146, in argnums_partial_except
hash(static_arg)
TypeError: unhashable type: 'NerfModel'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
app.run(main)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
state, stats, keys, model_out = ptrain_step(
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 1669, in f_pmapped
out = pxla.xla_pmap(
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1620, in bind
return call_bind(self, fun, *args, **params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1551, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1623, in process
return trace.process_map(self, fun, tracers, params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 606, in process_call
return primitive.impl(f, *tracers, **params)
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 712, in parallel_callable
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 413, in cache_miss
f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
app.run(main)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
state, stats, keys, model_out = ptrain_step(
File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
raise ValueError(
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.
When I run the training script, I encounter the following problem. By the way, training works fine when I run the training script from
Nerfies
with a similar config. The JAX version is as below: