tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

Models trained with fit_on_dataset_path behave unexpectedly #187

Closed: rstz closed this issue 1 year ago

rstz commented 1 year ago
          Hey @rstz,

Thank you for your quick response and your interest in the problem. I think we are dealing with a serialization/deserialization problem (and not necessarily with the data format itself). I've put together a simple Colab: https://colab.research.google.com/drive/19sepbkGXwM8lI6fZuovvYRAVSiYCNiKl?usp=sharing so that you can see more details and experiment with it. I've tested this with all versions (and, as mentioned in my previous comment, on various systems), and I can't figure out what is wrong. I hope this gives you more to go on and that we can quickly solve the problem (if there is one).

Thank you !!

Originally posted by @piotrlaczkowski in https://github.com/tensorflow/decision-forests/issues/136#issuecomment-1654145924

achoum commented 1 year ago

Unfortunately, in TensorFlow a loaded model is not equivalent to the original model. The original model contains several utilities that automatically convert and adapt the dataset format, while a loaded model is much more restrictive: it only accepts inputs that match one of its saved signatures exactly (feature names, dtypes and shapes).
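Because a loaded model only accepts inputs matching its saved signatures, it can help to compare a batch against the expected specs before calling the model. The sketch below assumes the expected specs are available as a dict of `tf.TensorSpec` (for example, transcribed from an error message like the one that follows); the function name `find_signature_mismatches` is hypothetical and not part of TensorFlow or TF-DF. The example values for `MPG` and `Cylinders` come from the batch printed later in this thread; the `Displacement` values are dummies.

```python
import tensorflow as tf


def find_signature_mismatches(batch, expected_specs):
    """Lists how a dict of tensors differs from a dict of tf.TensorSpec.

    `expected_specs` (hypothetical input) maps feature name -> tf.TensorSpec.
    """
    problems = []
    # Features the model expects but the batch lacks.
    for name in sorted(set(expected_specs) - set(batch)):
        problems.append(f"missing feature {name!r}")
    # Features the batch provides but the model does not accept.
    for name in sorted(set(batch) - set(expected_specs)):
        problems.append(f"unexpected feature {name!r}")
    # dtype and shape mismatches for features present on both sides.
    for name in sorted(set(batch) & set(expected_specs)):
        tensor, spec = batch[name], expected_specs[name]
        if tensor.dtype != spec.dtype:
            problems.append(
                f"{name!r}: dtype {tensor.dtype.name}, expected {spec.dtype.name}")
        if not spec.shape.is_compatible_with(tensor.shape):
            problems.append(
                f"{name!r}: shape {tensor.shape}, expected {spec.shape}")
    return problems


# Subset of the signature reported by the loaded model.
expected = {
    "Cylinders": tf.TensorSpec(shape=(None,), dtype=tf.float32),
    "MPG": tf.TensorSpec(shape=(None,), dtype=tf.float32),
}
# Batch shaped like the one that failed: int32 feature, (None, 1) shapes,
# and an extra feature (dummy values).
batch = {
    "Cylinders": tf.constant([[6], [4], [4]], dtype=tf.int32),
    "MPG": tf.constant([[18.0], [30.9], [26.0]]),
    "Displacement": tf.constant([[0.0], [0.0], [0.0]]),
}
for problem in find_signature_mismatches(batch, expected):
    print(problem)
```

Here the helper reports the extra `Displacement` feature, the int32 dtype of `Cylinders`, and the `(3, 1)` shapes, which are the same three kinds of mismatch discussed further down.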

The error:

ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
      Positional arguments (2 total):
        * OrderedDict([('MPG', <tf.Tensor 'inputs:0' shape=(None, 1) dtype=float32>),
                 ('Cylinders',
                  <tf.Tensor 'inputs_1:0' shape=(None, 1) dtype=int32>),
                 ('Displacement',
                  <tf.Tensor 'inputs_2:0' shape=(None, 1) dtype=float32>),
                 ('Horsepower',
                  <tf.Tensor 'inputs_3:0' shape=(None, 1) dtype=float32>),
                 ('Weight', <tf.Tensor 'inputs_4:0' shape=(None, 1) dtype=float32>),
                 ('Acceleration',
                  <tf.Tensor 'inputs_5:0' shape=(None, 1) dtype=float32>),
                 ('year', <tf.Tensor 'inputs_6:0' shape=(None, 1) dtype=int32>),
                 ('Origin', <tf.Tensor 'inputs_7:0' shape=(None, 1) dtype=int32>)])
        * False
      Keyword arguments: {}

  Expected these arguments to match one of the following 2 option(s):

    Option 1:
      Positional arguments (2 total):
        * {'Acceleration': TensorSpec(shape=(None,), dtype=tf.float32, name='Acceleration'),
           'Cylinders': TensorSpec(shape=(None,), dtype=tf.float32, name='Cylinders'),
           'Horsepower': TensorSpec(shape=(None,), dtype=tf.float32, name='Horsepower'),
           'MPG': TensorSpec(shape=(None,), dtype=tf.float32, name='MPG'),
           'Origin': TensorSpec(shape=(None,), dtype=tf.float32, name='Origin'),
           'Weight': TensorSpec(shape=(None,), dtype=tf.float32, name='Weight'),
           'year': TensorSpec(shape=(None,), dtype=tf.float32, name='year')}
        * True
      Keyword arguments: {}

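This error indicates that the provided batch does not match what the model expects, in three ways:

1. The features Cylinders, year and Origin are int32, but the signature expects float32 for every feature.
2. Each feature has shape (None, 1), but the signature expects shape (None,).
3. The batch contains Displacement, which is not one of the model's inputs.

Here is a proposed fix for those three issues: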

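import tensorflow as tf

# `ds` and `model_loaded` are defined in the Colab notebook linked above.
for batch in ds.batch(3).take(1):
    clean_batch = dict(batch)

    for k in clean_batch:
        # Drop the trailing size-1 axis, (None, 1) -> (None,), and cast to float32.
        # axis=-1 keeps the batch axis even when the batch holds a single example.
        clean_batch[k] = tf.cast(tf.squeeze(clean_batch[k], axis=-1), tf.float32)

    # The loaded model's signature has no "Displacement" input.
    del clean_batch["Displacement"]

    print(clean_batch)
    print(model_loaded.predict(clean_batch))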

Result:

{'MPG': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([18. , 30.9, 26. ], dtype=float32)>, 'Cylinders': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([6., 4., 4.], dtype=float32)>, 'Horsepower': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([88., 75., 93.], dtype=float32)>, 'Weight': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([3021., 2230., 2391.], dtype=float32)>, 'Acceleration': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([16.5, 14.5, 15.5], dtype=float32)>, 'year': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([73., 78., 74.], dtype=float32)>, 'Origin': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 3.], dtype=float32)>}
1/1 [==============================] - 0s 51ms/step
[[221.37837 ]
 [110.83273 ]
 [107.609535]]
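The fix above adapts a single batch. To run predictions over a whole dataset, the same three transformations could be applied inside the `tf.data` pipeline with `Dataset.map`. This is a minimal sketch under assumptions: `EXPECTED_FEATURES` and `to_serving_format` are hypothetical names, the feature list is transcribed from the "Option 1" signature in the error above, and the dataset built here is a dummy stand-in for the Colab's `ds` (all-ones values, int32 for the features the error reports as int32).

```python
import tensorflow as tf

# Hypothetical: input names from the "Option 1" signature ("Displacement" is absent).
EXPECTED_FEATURES = ["Acceleration", "Cylinders", "Horsepower", "MPG",
                     "Origin", "Weight", "year"]


def to_serving_format(batch):
    """Keeps only expected features, squeezes (None, 1) -> (None,), casts to float32."""
    return {
        name: tf.cast(tf.squeeze(batch[name], axis=-1), tf.float32)
        for name in EXPECTED_FEATURES
    }


# Dummy stand-in for `ds`: 5 examples, each feature of shape (1,) per example.
int_features = {"Cylinders", "Origin", "year"}
features = {
    name: tf.ones((5, 1), dtype=tf.int32 if name in int_features else tf.float32)
    for name in EXPECTED_FEATURES + ["Displacement"]
}
ds = tf.data.Dataset.from_tensor_slices(features)

# Batch first so each feature is (None, 1), then adapt it to the signature.
serving_ds = ds.batch(3).map(to_serving_format)
print(serving_ds.element_spec)
```

Assuming `model_loaded` is the Keras model loaded in the Colab, `serving_ds` could then be passed to `model_loaded.predict(serving_ds)`. Doing the conversion in `map` keeps it inside the input pipeline, so it also applies to every batch rather than only the first one.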
rstz commented 1 year ago

Closing this; feel free to reopen if you need additional information or assistance.