poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License
469 stars 32 forks source link

TypeError: dot_general() got an unexpected keyword argument 'use_v2' #240

Open · mwitiderrick opened this issue 2 years ago

mwitiderrick commented 2 years ago

On Kaggle notebooks, saving the model with `model.saved_model(x_sample, "saved-models/high-level")` raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_33/3229665597.py in <module>
----> 1 model.saved_model(x_sample, "saved-models/high-level")

/opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in saved_model(self, inputs, path, batch_size)
    769             enable_xla=True,
    770             compile_model=False,
--> 771             save_model_options=None,
    772         )

/opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in convert_and_save_model(jax_fn, params, model_dir, input_signatures, shape_polymorphic_input_spec, with_gradient, enable_xla, compile_model, save_model_options)
     99         signatures[
    100             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
--> 101         ] = tf_fun.get_concrete_function(input_signatures[0])
    102 
    103         for input_signature in input_signatures[1:]:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1231   def get_concrete_function(self, *args, **kwargs):
   1232     # Implements GenericFunction.get_concrete_function.
-> 1233     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1234     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1235     return concrete

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1211       if self._stateful_fn is None:
   1212         initializers = []
-> 1213         self._initialize(args, kwargs, add_initializers_to=initializers)
   1214         self._initialize_uninitialized_variables(initializers)
   1215 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    761 
    762     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3306             arg_names=arg_names,
   3307             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308             capture_by_value=self._capture_by_value),
   3309         self._function_attributes,
   3310         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in <lambda>(inputs)
     90         )
     91         tf_fun = tf.function(
---> 92             lambda inputs: tf_fn(param_vars, inputs),
     93             autograph=False,
     94             experimental_compile=compile_model,

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in converted_fun(*args, **kwargs)
    441       else:
    442         out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
--> 443                                         name_stack, fresh_constant_cache=True)
    444         outs, out_avals = util.unzip2(out_with_avals)
    445         message = ("The jax2tf-converted function does not support gradients. "

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _interpret_fun(fun, in_vals, in_avals, extra_name_stack, fresh_constant_cache)
    511         out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
    512             _call_wrapped_with_new_constant_cache(fun, in_vals,
--> 513                                                   fresh_constant_cache=fresh_constant_cache)
    514 
    515       del main

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _call_wrapped_with_new_constant_cache(fun, in_vals, fresh_constant_cache)
    530 
    531     out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
--> 532         fun.call_wrapped(*in_vals)
    533   finally:
    534     if prev_constant_cache is not None and not fresh_constant_cache:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    166 
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:
    170       # Some transformations yield from inside context managers, so we have to

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in fun_no_kwargs(*args_and_kwargs)
    284       kwargs = {kw: args_and_kwargs[nr_positional_args + i]
    285                 for i, kw in enumerate(kw_names)}
--> 286       return fun(*args, **kwargs)
    287 
    288     def check_arg(a):

/opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in jax_fn(flat_states, inputs)
    755 
    756             y_pred, _ = model.pred_step(
--> 757                 inputs=inputs,
    758             )
    759 

/opt/conda/lib/python3.7/site-packages/elegy/model/model.py in pred_step(self, inputs)
    196         inputs_obj = tx.Inputs.from_value(inputs)
    197 
--> 198         preds = model.module(*inputs_obj.args, **inputs_obj.kwargs)
    199 
    200         return preds, model

/opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs)
    114             @functools.wraps(cls.__call__)
    115             def new_call(self: Module, *args, **kwargs):
--> 116                 outputs = orig_call(self, *args, **kwargs)
    117 
    118                 if (

/opt/conda/lib/python3.7/site-packages/treeo/api.py in wrapper(tree, *args, **kwargs)
    518     def wrapper(tree, *args, **kwargs):
    519         with tree_m._COMPACT_CONTEXT.compact(f, tree):
--> 520             return f(tree, *args, **kwargs)
    521 
    522     wrapper._treeo_compact = True

/tmp/ipykernel_33/786988758.py in __call__(self, x)
     14         x = eg.nn.Flatten()(x)
     15         # first layers
---> 16         x = eg.nn.Linear(self.n1)(x)
     17         x = jax.nn.relu(x)
     18         # first layers

/opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs)
    114             @functools.wraps(cls.__call__)
    115             def new_call(self: Module, *args, **kwargs):
--> 116                 outputs = orig_call(self, *args, **kwargs)
    117 
    118                 if (

/opt/conda/lib/python3.7/site-packages/treex/nn/linear.py in __call__(self, x)
    119             params["bias"] = self.bias
    120 
--> 121         output = self.module.apply({"params": params}, x)
    122         return tp.cast(jnp.ndarray, output)

    [... skipping hidden 7 frame]

/opt/conda/lib/python3.7/site-packages/flax/linen/linear.py in __call__(self, inputs)
    188     y = lax.dot_general(inputs, kernel,
    189                         (((inputs.ndim - 1,), (0,)), ((), ())),
--> 190                         precision=self.precision)
    191     if self.use_bias:
    192       bias = self.param('bias', self.bias_init, (self.features,),

    [... skipping hidden 3 frame]

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in process_primitive(self, primitive, tracers, params)
    843           val_out = invoke_impl()
    844       else:
--> 845         val_out = invoke_impl()
    846 
    847     if primitive.multiple_results:

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in invoke_impl()
    809             _in_avals=args_avals,  # type: ignore
    810             _out_aval=out_aval,
--> 811             **params)
    812       else:
    813         return impl(*args_tf, **params)

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type, _in_avals, _out_aval)
   1595       precision_config_proto,
   1596       preferred_element_type=preferred_element_type,
-> 1597       use_v2=True)
   1598   if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
   1599     res = tf.stop_gradient(res)  # See #7839

TypeError: dot_general() got an unexpected keyword argument 'use_v2'