google / hypernerf

Code for "HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields".
https://hypernerf.github.io
Apache License 2.0

ValueError: Non-hashable static arguments are not supported #33

Closed: Spark001 closed this issue 2 years ago

Spark001 commented 2 years ago

When I run the training script, I get the error below. By the way, running the Nerfies training script with a similar config works fine.

train.py:302] Starting training
Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 146, in argnums_partial_except
    hash(static_arg)
TypeError: unhashable type: 'NerfModel'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
    app.run(main)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
    state, stats, keys, model_out = ptrain_step(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 712, in parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 413, in cache_miss
    f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
    app.run(main)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
    state, stats, keys, model_out = ptrain_step(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
    raise ValueError(
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.
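For context: JAX uses static arguments as part of the key for its compilation cache, so it calls `hash()` on each one (the first frame of the trace shows `hash(static_arg)` raising `TypeError`). Here the static argument is the model object. A minimal pure-Python sketch of why an object can fail that check, assuming the model class behaves like a standard `dataclasses.dataclass` (Flax modules are dataclasses; whether `NerfModel` ends up hashable presumably depends on the installed Flax version, which is an assumption not confirmed in this thread). `MutableModel`, `HashableModel` and `is_valid_static_arg` are hypothetical names for illustration.

```python
import dataclasses


# Hypothetical stand-in: a plain dataclass with the default eq=True and
# frozen=False gets __hash__ = None, so its instances are unhashable.
@dataclasses.dataclass
class MutableModel:
    num_layers: int = 8


# Hypothetical stand-in: unsafe_hash=True generates a field-based __hash__.
@dataclasses.dataclass(unsafe_hash=True)
class HashableModel:
    num_layers: int = 8


def is_valid_static_arg(arg) -> bool:
    """Mirror the basic requirement on static arguments: they must be hashable."""
    try:
        hash(arg)
    except TypeError:
        return False
    return True


print(is_valid_static_arg(MutableModel()))   # False
print(is_valid_static_arg(HashableModel()))  # True
```

The same `TypeError: unhashable type` seen at the top of the trace is what `hash(MutableModel())` raises; JAX then converts it into the `ValueError` about non-hashable static arguments.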

The JAX versions are:

jax                                0.2.17
jaxlib                             0.1.66+cuda110

Spark001 commented 2 years ago

Solved by updating the package versions to match requirements.txt.
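Since the fix was aligning installed packages with `requirements.txt`, a quick way to spot mismatches is to compare exact pins against what is installed. A minimal sketch using only the standard library's `importlib.metadata`, assuming the file uses simple `name==version` lines (other specifiers are skipped). `check_pins` is a hypothetical helper, and the sketch makes no assumption about which versions HyperNeRF actually pins.

```python
from importlib import metadata


def check_pins(requirement_lines):
    """Return (package, pinned, installed) for every '==' pin that does not match.

    Only exact 'name==version' pins are checked; '>=', '~=' and unpinned
    lines are skipped. 'installed' is None when the package is missing.
    """
    mismatches = []
    for raw in requirement_lines:
        # Drop comments and environment markers, then surrounding whitespace.
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if "==" not in line:
            continue
        name, pinned = (part.strip() for part in line.split("==", 1))
        name = name.split("[", 1)[0].strip()  # drop extras such as jax[cuda]
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches.append((name, pinned, installed))
    return mismatches
```

For example, `check_pins(open("requirements.txt"))` returns a list of `(package, pinned, installed)` tuples; an empty list means every exact pin matches the environment.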