titu1994 / tfdiffeq

TensorFlow implementation of Ordinary Differential Equation Solvers with full GPU support
MIT License

A problem with inserting ODENet Model #8

Closed: min443 closed this issue 4 years ago

min443 commented 4 years ago

Hi. According to your README, ODENet can be used as a layer inside a neural network.

Here is the example from the README:

Used inside other models

x = Conv2D(...)(x)
x = Conv2D(...)(x)
x = Flatten()(x)
x = ODENet(...)(x)  # or don't use Flatten and use ConvODENet directly
x = ODENet(...)(x)  # or don't use Flatten and use ConvODENet directly
...
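In the README pattern, every line does two things: `Conv2D(...)` constructs a layer object, and the trailing `(x)` calls that layer on the previous output, so `x` stays a tensor at each step. A minimal plain-Python sketch of that construct-then-call pattern follows; `ScaleLayer` is a hypothetical stand-in, not a Keras class, used only so the example runs without TensorFlow.

```python
class ScaleLayer:
    """Hypothetical stand-in for a Keras layer (not a real Keras class).

    Construction stores configuration only; calling the instance applies
    the layer to an input and returns new data.
    """

    def __init__(self, factor):
        self.factor = factor  # configuration only; no input is touched yet

    def __call__(self, inputs):
        # Applying the layer returns new data, analogous to a Keras layer
        # call returning an output tensor.
        return [v * self.factor for v in inputs]


x = [1.0, 2.0, 3.0]
x = ScaleLayer(2.0)(x)   # construct, then call: x is still data
x = ScaleLayer(10.0)(x)  # chaining works because each call returns data
print(x)  # [20.0, 40.0, 60.0]

layer = ScaleLayer(2.0)  # construct only: this is a layer object, not data
print(isinstance(layer, list))  # False
```

Dropping the trailing call, as in `layer = ScaleLayer(2.0)`, leaves a layer object rather than data, and that object is what any following layer would then receive.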

Here's my code.

y = Input(shape=(20,))
y = Dense(10)
y = ODENet(hidden_dim=10, output_dim=10)(y)

However, I get the following error when I run the code.

AssertionError                            Traceback (most recent call last)
<ipython-input-4-807a3c818cf4> in <module>
      1 y = Input(shape=(20,))
      2 y = Dense(10)
----> 3 y = ODENet(hidden_dim=10, output_dim=10)(y)

~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    896           with base_layer_utils.autocast_context_manager(
    897               self._compute_dtype):
--> 898             outputs = self.call(cast_inputs, *args, **kwargs)
    899           self._handle_activity_regularization(inputs, outputs)
    900           self._set_mask_metadata(inputs, outputs, input_masks)

~/tfdiffeq/tfdiffeq/models/dense_odenet.py in call(self, x, training, return_features)
    252     # @tf.function
    253     def call(self, x, training=None, return_features=False):
--> 254         features = self.odeblock(x, training=training)
    255 
    256         pred = self.linear_layer(features)

~/.local/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    896           with base_layer_utils.autocast_context_manager(
    897               self._compute_dtype):
--> 898             outputs = self.call(cast_inputs, *args, **kwargs)
    899           self._handle_activity_regularization(inputs, outputs)
    900           self._set_mask_metadata(inputs, outputs, input_masks)

~/tfdiffeq/tfdiffeq/models/dense_odenet.py in call(self, x, training, eval_times, **kwargs)
    184             out = odeint(self.odefunc, x_aug, integration_time,
    185                          rtol=self.tol, atol=self.tol, method=self.method,
--> 186                          options=self.options)
    187 
    188         if eval_times is None:

~/tfdiffeq/tfdiffeq/odeint.py in odeint(func, y0, t, rtol, atol, method, options)
     66             an invalid dtype.
     67     """
---> 68     tensor_input, func, y0, t = _check_inputs(func, y0, t)
     69 
     70     if options is None:

~/tfdiffeq/tfdiffeq/misc.py in _check_inputs(func, y0, t)
    303         func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
    304 
--> 305     assert isinstance(y0, tuple), 'y0 must be either a tf.Tensor or a tuple'
    306     if ((type(y0) == tuple) or (type(y0) == list)):
    307         if not tensor_input:

AssertionError: y0 must be either a tf.Tensor or a tuple

Could you tell me what I did wrong?
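The likely cause is line 2 of the snippet: `y = Dense(10)` constructs a `Dense` layer but never calls it, so `y` is rebound to the layer object and the `Input` tensor is discarded. `ODENet` then receives that layer object instead of a tensor, and it reaches tfdiffeq's `_check_inputs`, whose assertion accepts only a `tf.Tensor` or a tuple. In the Keras functional API the layer has to be applied to its input, i.e. `y = Dense(10)(y)`. The sketch below reproduces the failing check in plain Python; `FakeTensor`, `FakeDense` and `check_inputs` are hypothetical stand-ins modeled on the `misc.py` lines visible in the traceback, not tfdiffeq's actual code.

```python
class FakeTensor:
    """Hypothetical stand-in for tf.Tensor."""

    def __init__(self, values):
        self.values = values


class FakeDense:
    """Hypothetical stand-in for a Keras Dense layer."""

    def __init__(self, units):
        self.units = units

    def __call__(self, inputs):
        # The input values are ignored here; a real Dense layer would apply
        # a learned linear map. Returning a tensor of width `units` is
        # enough to show that *calling* the layer yields a tensor.
        return FakeTensor([0.0] * self.units)


def check_inputs(y0):
    """Sketch of the y0 validation that raised the AssertionError.

    A single tensor is wrapped into a 1-tuple; anything that is still
    not a tuple afterwards is rejected by the assertion.
    """
    if isinstance(y0, FakeTensor):
        y0 = (y0,)
    assert isinstance(y0, tuple), 'y0 must be either a tf.Tensor or a tuple'
    return y0


y = FakeTensor([0.0] * 20)  # plays the role of Input(shape=(20,))

# Buggy version: the layer is constructed but never applied to y.
bad = FakeDense(10)
error_message = None
try:
    check_inputs(bad)
except AssertionError as err:
    error_message = str(err)
    print('AssertionError:', error_message)

# Fixed version: construct the layer, then call it on y.
good = FakeDense(10)(y)
print(len(check_inputs(good)))  # 1
```

The same two-step rule applies to every layer in the chain, including `ODENet` itself, which the original snippet already calls correctly on line 3.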