Status: Open — issue opened by mwitiderrick 2 years ago
On Kaggle notebooks, saving the model with `model.saved_model(x_sample, "saved-models/high-level")` raises the following error:
model.saved_model(x_sample, "saved-models/high-level")
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /tmp/ipykernel_33/3229665597.py in <module> ----> 1 model.saved_model(x_sample, "saved-models/high-level") /opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in saved_model(self, inputs, path, batch_size) 769 enable_xla=True, 770 compile_model=False, --> 771 save_model_options=None, 772 ) /opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in convert_and_save_model(jax_fn, params, model_dir, input_signatures, shape_polymorphic_input_spec, with_gradient, enable_xla, compile_model, save_model_options) 99 signatures[ 100 tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY --> 101 ] = tf_fun.get_concrete_function(input_signatures[0]) 102 103 for input_signature in input_signatures[1:]: /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs) 1231 def get_concrete_function(self, *args, **kwargs): 1232 # Implements GenericFunction.get_concrete_function. 
-> 1233 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) 1234 concrete._garbage_collector.release() # pylint: disable=protected-access 1235 return concrete /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs) 1211 if self._stateful_fn is None: 1212 initializers = [] -> 1213 self._initialize(args, kwargs, add_initializers_to=initializers) 1214 self._initialize_uninitialized_variables(initializers) 1215 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to) 758 self._concrete_stateful_fn = ( 759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access --> 760 *args, **kwds)) 761 762 def invalid_creator_scope(*unused_args, **unused_kwds): /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs) 3064 args, kwargs = None, None 3065 with self._lock: -> 3066 graph_function, _ = self._maybe_define_function(args, kwargs) 3067 return graph_function 3068 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs) 3461 3462 self._function_cache.missed.add(call_context_key) -> 3463 graph_function = self._create_graph_function(args, kwargs) 3464 self._function_cache.primary[cache_key] = graph_function 3465 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes) 3306 arg_names=arg_names, 3307 override_flat_arg_shapes=override_flat_arg_shapes, -> 3308 capture_by_value=self._capture_by_value), 3309 self._function_attributes, 3310 function_spec=self.function_spec, /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, 
signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses) 1005 _, original_func = tf_decorator.unwrap(python_func) 1006 -> 1007 func_outputs = python_func(*func_args, **func_kwargs) 1008 1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors, /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds) 666 # the function a weak reference to itself to avoid a reference cycle. 667 with OptionalXlaContext(compile_with_xla): --> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds) 669 return out 670 /opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in <lambda>(inputs) 90 ) 91 tf_fun = tf.function( ---> 92 lambda inputs: tf_fn(param_vars, inputs), 93 autograph=False, 94 experimental_compile=compile_model, /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in converted_fun(*args, **kwargs) 441 else: 442 out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat, --> 443 name_stack, fresh_constant_cache=True) 444 outs, out_avals = util.unzip2(out_with_avals) 445 message = ("The jax2tf-converted function does not support gradients. 
" /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _interpret_fun(fun, in_vals, in_avals, extra_name_stack, fresh_constant_cache) 511 out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ 512 _call_wrapped_with_new_constant_cache(fun, in_vals, --> 513 fresh_constant_cache=fresh_constant_cache) 514 515 del main /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _call_wrapped_with_new_constant_cache(fun, in_vals, fresh_constant_cache) 530 531 out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ --> 532 fun.call_wrapped(*in_vals) 533 finally: 534 if prev_constant_cache is not None and not fresh_constant_cache: /opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 166 167 try: --> 168 ans = self.f(*args, **dict(self.params, **kwargs)) 169 except: 170 # Some transformations yield from inside context managers, so we have to /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in fun_no_kwargs(*args_and_kwargs) 284 kwargs = {kw: args_and_kwargs[nr_positional_args + i] 285 for i, kw in enumerate(kw_names)} --> 286 return fun(*args, **kwargs) 287 288 def check_arg(a): /opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in jax_fn(flat_states, inputs) 755 756 y_pred, _ = model.pred_step( --> 757 inputs=inputs, 758 ) 759 /opt/conda/lib/python3.7/site-packages/elegy/model/model.py in pred_step(self, inputs) 196 inputs_obj = tx.Inputs.from_value(inputs) 197 --> 198 preds = model.module(*inputs_obj.args, **inputs_obj.kwargs) 199 200 return preds, model /opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs) 114 @functools.wraps(cls.__call__) 115 def new_call(self: Module, *args, **kwargs): --> 116 outputs = orig_call(self, *args, **kwargs) 117 118 if ( /opt/conda/lib/python3.7/site-packages/treeo/api.py in wrapper(tree, *args, **kwargs) 518 def wrapper(tree, *args, **kwargs): 519 with 
tree_m._COMPACT_CONTEXT.compact(f, tree): --> 520 return f(tree, *args, **kwargs) 521 522 wrapper._treeo_compact = True /tmp/ipykernel_33/786988758.py in __call__(self, x) 14 x = eg.nn.Flatten()(x) 15 # first layers ---> 16 x = eg.nn.Linear(self.n1)(x) 17 x = jax.nn.relu(x) 18 # first layers /opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs) 114 @functools.wraps(cls.__call__) 115 def new_call(self: Module, *args, **kwargs): --> 116 outputs = orig_call(self, *args, **kwargs) 117 118 if ( /opt/conda/lib/python3.7/site-packages/treex/nn/linear.py in __call__(self, x) 119 params["bias"] = self.bias 120 --> 121 output = self.module.apply({"params": params}, x) 122 return tp.cast(jnp.ndarray, output) [... skipping hidden 7 frame] /opt/conda/lib/python3.7/site-packages/flax/linen/linear.py in __call__(self, inputs) 188 y = lax.dot_general(inputs, kernel, 189 (((inputs.ndim - 1,), (0,)), ((), ())), --> 190 precision=self.precision) 191 if self.use_bias: 192 bias = self.param('bias', self.bias_init, (self.features,), [... skipping hidden 3 frame] /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in process_primitive(self, primitive, tracers, params) 843 val_out = invoke_impl() 844 else: --> 845 val_out = invoke_impl() 846 847 if primitive.multiple_results: /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in invoke_impl() 809 _in_avals=args_avals, # type: ignore 810 _out_aval=out_aval, --> 811 **params) 812 else: 813 return impl(*args_tf, **params) /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type, _in_avals, _out_aval) 1595 precision_config_proto, 1596 preferred_element_type=preferred_element_type, -> 1597 use_v2=True) 1598 if _WRAP_JAX_JIT_WITH_TF_FUNCTION: 1599 res = tf.stop_gradient(res) # See #7839 TypeError: dot_general() got an unexpected keyword argument 'use_v2'
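In short: on Kaggle notebooks, saving the model with the call below fails with `TypeError: dot_general() got an unexpected keyword argument 'use_v2'`. The error is raised inside `jax2tf` when it calls TensorFlow's XLA `dot_general`. This suggests that the installed JAX (`jax2tf`) and TensorFlow versions are incompatible.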
model.saved_model(x_sample, "saved-models/high-level")