google / objax

Apache License 2.0
769 stars 77 forks source link

Override default momentum #92

Closed sathish-a closed 4 years ago

sathish-a commented 4 years ago

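@david-berthelot, following up on this: I've changed the momentum module so that the default momentum value can be overridden on each call. However, I'm hitting a `ConcretizationTypeError` ("The problem arose with the `bool` function.") caused by the line `momentum = momentum or self.momentum` inside `__call__`. The cause is that `__call__` runs under `jit`, so `momentum` is an abstract traced value rather than a concrete number. The `or` operator then calls `bool()` on that traced value, which requires a concrete value. I believe this is why I was previously assigning `self.momentum = momentum` instead. The detailed error log is below.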

self = <optimizer.TestOptimizers testMethod=test_square_momentum_override>

    def test_square_momentum_override(self):
        """Test logistic loss for momentum optimizer."""
>       model_vars, loss = self._test_loss_opt('square', 'momentum', True)

tests/optimizer.py:140:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/optimizer.py:115: in _test_loss_opt
    self._check_run(gv, opt, loss, lr, self.num_steps, tolerance, momentum)
tests/optimizer.py:100: in _check_run
    opt(lr, g, momentum)
objax/module.py:217: in __call__
    output, changes = self._call(self.vc.tensors(), kwargs, *args)
/usr/local/lib/python3.6/dist-packages/jax/api.py:215: in f_jitted
    donated_invars=donated_invars)
/usr/local/lib/python3.6/dist-packages/jax/core.py:1144: in bind
    return call_bind(self, fun, *args, **params)
/usr/local/lib/python3.6/dist-packages/jax/core.py:1135: in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
/usr/local/lib/python3.6/dist-packages/jax/core.py:1147: in process
    return trace.process_call(self, fun, tracers, params)
/usr/local/lib/python3.6/dist-packages/jax/core.py:577: in process_call
    return primitive.impl(f, *tracers, **params)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py:530: in _xla_call_impl
    *unsafe_map(arg_spec, args))
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py:234: in memoized_fun
    ans = call(fun, *args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py:595: in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py:1023: in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py:1004: in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py:151: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
objax/module.py:207: in jit
    return f(*args, **kwargs), self.vc.tensors(BaseState)
objax/optimizer/momentum.py:50: in __call__
    momentum = momentum or self.momentum
/usr/local/lib/python3.6/dist-packages/jax/core.py:507: in __bool__
    def __bool__(self): return self.aval._bool(self)
/usr/local/lib/python3.6/dist-packages/jax/core.py:864: in error
    raise_concretization_error(arg, fname_context)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

val = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, context = 'The problem arose with the `bool` function. '

    def raise_concretization_error(val: Tracer, context=""):
      msg = ("Abstract tracer value encountered where concrete value is expected.\n\n"
             + context + "\n\n"
             + val._origin_msg() + "\n\n"
             + "You can use transformation parameters such as `static_argnums` for "
             "`jit` to avoid tracing particular arguments of transformed functions.\n\n"
             "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
              f"Encountered tracer value: {val}")
>     raise ConcretizationTypeError(msg)
E     jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
E
E     The problem arose with the `bool` function.
E
E     The error occured while tracing the function jit at /home/data/sathish/satz_objax/objax/objax/module.py:203.
E
E     You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
E
E     See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
E
E     Encountered tracer value: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

/usr/local/lib/python3.6/dist-packages/jax/core.py:853: ConcretizationTypeError