@david-berthelot, with reference to this, I've changed the momentum module so that the default momentum value can be overridden on each call.
I'm now hitting a `ConcretizationTypeError` ("The problem arose with the `bool` function.") caused by the assignment `momentum = momentum or self.momentum` inside `__call__`. I remember this is why I was doing `self.momentum = momentum` before.
The detailed error log is below.
self = <optimizer.TestOptimizers testMethod=test_square_momentum_override>
def test_square_momentum_override(self):
"""Test logistic loss for momentum optimizer."""
> model_vars, loss = self._test_loss_opt('square', 'momentum', True)
tests/optimizer.py:140:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/optimizer.py:115: in _test_loss_opt
self._check_run(gv, opt, loss, lr, self.num_steps, tolerance, momentum)
tests/optimizer.py:100: in _check_run
opt(lr, g, momentum)
objax/module.py:217: in __call__
output, changes = self._call(self.vc.tensors(), kwargs, *args)
/usr/local/lib/python3.6/dist-packages/jax/api.py:215: in f_jitted
donated_invars=donated_invars)
/usr/local/lib/python3.6/dist-packages/jax/core.py:1144: in bind
return call_bind(self, fun, *args, **params)
/usr/local/lib/python3.6/dist-packages/jax/core.py:1135: in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
/usr/local/lib/python3.6/dist-packages/jax/core.py:1147: in process
return trace.process_call(self, fun, tracers, params)
/usr/local/lib/python3.6/dist-packages/jax/core.py:577: in process_call
return primitive.impl(f, *tracers, **params)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py:530: in _xla_call_impl
*unsafe_map(arg_spec, args))
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py:234: in memoized_fun
ans = call(fun, *args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py:595: in _xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py:1023: in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py:1004: in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py:151: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
objax/module.py:207: in jit
return f(*args, **kwargs), self.vc.tensors(BaseState)
objax/optimizer/momentum.py:50: in __call__
momentum = momentum or self.momentum
/usr/local/lib/python3.6/dist-packages/jax/core.py:507: in __bool__
def __bool__(self): return self.aval._bool(self)
/usr/local/lib/python3.6/dist-packages/jax/core.py:864: in error
raise_concretization_error(arg, fname_context)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
val = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, context = 'The problem arose with the `bool` function. '
def raise_concretization_error(val: Tracer, context=""):
msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
+ context + "\n\n"
+ val._origin_msg() + "\n\n"
+ "You can use transformation parameters such as `static_argnums` for "
"`jit` to avoid tracing particular arguments of transformed functions.\n\n"
"See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
f"Encountered tracer value: {val}")
> raise ConcretizationTypeError(msg)
E jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
E
E The problem arose with the `bool` function.
E
E The error occured while tracing the function jit at /home/data/sathish/satz_objax/objax/objax/module.py:203.
E
E You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
E
E See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
E
E Encountered tracer value: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
/usr/local/lib/python3.6/dist-packages/jax/core.py:853: ConcretizationTypeError
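A minimal sketch of why this fails and one common fix, using plain `jax.jit` rather than objax's wrapper. `DEFAULT_MOMENTUM`, `step_buggy` and `step_fixed` are hypothetical names standing in for `self.momentum` and the optimizer's `__call__`. Under `jit`, the `momentum` argument is an abstract tracer, and `x or y` must call `bool(x)`, which needs a concrete value. Testing `momentum is None` only checks object identity, so it works on tracers.

```python
from typing import Optional

import jax
import jax.numpy as jnp

# Hypothetical stand-in for `self.momentum` on the optimizer.
DEFAULT_MOMENTUM = 0.9


def step_buggy(grad, momentum: Optional[float] = None):
    # `momentum or DEFAULT_MOMENTUM` calls bool(momentum). Under jax.jit the
    # argument is an abstract tracer with no concrete value, so this raises
    # a ConcretizationTypeError, the same kind of failure as in the log above.
    momentum = momentum or DEFAULT_MOMENTUM
    return momentum * grad


def step_fixed(grad, momentum: Optional[float] = None):
    # `is None` is a Python identity check and never calls bool() on the
    # tracer, so it is safe inside jit. A traced value is used as-is.
    if momentum is None:
        momentum = DEFAULT_MOMENTUM
    return momentum * grad


if __name__ == "__main__":
    g = jnp.ones(3)
    print(jax.jit(step_fixed)(g, 0.5))  # override the default momentum
    print(jax.jit(step_fixed)(g))       # fall back to DEFAULT_MOMENTUM
    try:
        jax.jit(step_buggy)(g, 0.5)
    except jax.errors.ConcretizationTypeError as err:
        print("step_buggy failed:", type(err).__name__)
```

The `or` form also has a quieter problem even outside `jit`: an explicit `momentum=0.0` is falsy and would silently fall back to the default. The alternative the error message suggests, marking the argument static via `static_argnums`, also avoids the error, but it recompiles the function for every distinct momentum value.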