PAIR-code / what-if-tool

Source code/webpage/demos for the What-If Tool
https://pair-code.github.io/what-if-tool
Apache License 2.0

Use WIT for model trained in tfx #37

Open. orenkobo opened this issue 4 years ago

orenkobo commented 4 years ago

Hi, I trained a model with TFX and it was exported as saved_model.pb. Now I want to reload it and visualize it using WIT. How can I do this?

I couldn't find a way to do it, since when reloading the model with imported = tf.saved_model.load(export_dir=trained_model_path) I get an object of type <tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7f3d71e456a0> instead of an estimator.

Thanks
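
For context, a minimal way to inspect what tf.saved_model.load returns in TF 2.x (trained_model_path stands in for the actual export directory, as in the question above):

import tensorflow as tf

# In TF 2.x, loading an Estimator-exported SavedModel yields a generic
# trackable object rather than an Estimator; its serving signatures
# live under the .signatures attribute.
imported = tf.saved_model.load(export_dir=trained_model_path)
print(type(imported))                    # AutoTrackable, not an Estimator
print(list(imported.signatures.keys()))  # e.g. ['predict', 'serving_default']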

jameswex commented 4 years ago

Looking at the official documentation (https://www.tensorflow.org/guide/saved_model#savedmodels_from_estimators), it seems that when you load a saved model from disk, what you get back is not an estimator. But you should still be able to call predict on that object by defining your own custom prediction function, as is done in that documentation, and then providing that custom predict function to the WitConfigBuilder.

Let me know if an approach similar to the predict(x) function in that link works for you.
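
For reference, the predict(x) pattern from that guide looks roughly like this (the "x" feature name and the "predict" signature belong to the guide's toy model, not necessarily to the model in this issue):

import tensorflow as tf

imported = tf.saved_model.load(export_dir)

def predict(x):
    # Wrap the input in a tf.Example, serialize it, and call the
    # restored model's "predict" signature on the serialized batch.
    example = tf.train.Example()
    example.features.feature["x"].float_list.value.extend([x])
    return imported.signatures["predict"](
        examples=tf.constant([example.SerializeToString()]))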

orenkobo commented 4 years ago

@jameswex When using the predict function:

def predict(x):
    example = tf.train.Example()
    example.features.feature["x"].float_list.value.extend([x])
    return imported.signatures["predict"](
        examples=tf.constant([example.SerializeToString()]))

config_builder = WitConfigBuilder(test_examples, feats + ['level']).set_estimator_and_feature_spec(predict, feature_spec=[])
WitWidget(config_builder, height=1600)

(where imported = tf.saved_model.load(export_dir=trained_model_path), which is of type <tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7f3d71e456a0>)

I get the error:

<_Rendezvous of RPC that terminated with:
    status = StatusCode.UNAVAILABLE
    details = "DNS resolution failed"
    debug_error_string = "{"created":"@1578211571.031196087","description":"Failed to pick subchannel","file":"src/core/ext/filters/client_channel/client_channel.cc","file_line":3818,"referenced_errors":[{"created":"@1578211571.031189371","description":"Resolver transient failure","file":"src/core/ext/filters/client_channel/resolving_lb_policy.cc","file_line":268,"referenced_errors":[{"created":"@1578211571.031187685","description":"DNS resolution failed","file":"src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.cc","file_line":357,"grpc_status":14,"referenced_errors":[{"created":"@1578211571.031167691","description":"C-ares status is not ARES_SUCCESS: Domain name not found","file":"src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.cc","file_line":244}]}]}]}"
>

jameswex commented 4 years ago

Since you have defined your own custom prediction function, instead of using a tf.Estimator, you want to change your code to something like:

config_builder = WitConfigBuilder(test_examples, feats + ['level']).set_custom_predict_fn(predict)
WitWidget(config_builder, height=1600)

orenkobo commented 4 years ago

@jameswex OK, this is better now, but I have a problem: my features are of list types:

features {
  feature {
    key: "b_number"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "c_type"
    value {
      bytes_list {
        value: "motor"
      }
    }
  }

So I get the error:

[features { feature { key: "bearing_number" value { int64_list { value: 1 has type list, but expected one of: int, long, float

I have more than 30 features in total, and they are all of types float_list / int64_list / bytes_list. What is the best way to convert them all to int / long / float?
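
One possible approach, sketched under the assumption that each feature holds exactly one value in its value list (the helper names are made up here, not part of WIT or TF):

def feature_to_scalar(feature):
    # A tf.train.Feature stores its payload in a oneof named "kind":
    # float_list, int64_list, or bytes_list. Pull out the first value.
    kind = feature.WhichOneof("kind")
    value = getattr(feature, kind).value[0]
    return value.decode("utf-8") if kind == "bytes_list" else value

def example_to_scalars(example):
    # Flatten a tf.Example into {feature_name: plain Python scalar}.
    return {name: feature_to_scalar(feat)
            for name, feat in example.features.feature.items()}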

jameswex commented 4 years ago

Are you able to share a Colab notebook with your code that loads your saved model, so I could see the issue? I'm imagining that perhaps the reloaded saved model wants the examples in a very different format than tf.Example, so some conversion function will be necessary, but it's hard to know what that will need to be without playing with it myself.

orenkobo commented 4 years ago

@jameswex It's internal code, so it would be problematic to share. I'll try to play with it and make it work. Thanks!

jameswex commented 4 years ago

Looking at the example in the link I sent above, it seems your custom predict fn might need to take the provided tf.Examples, serialize them, and wrap them in a tf.constant, like:

def predict(examples):
    return imported.signatures["predict"](
        examples=tf.constant([ex.SerializeToString() for ex in examples]))

That would be due to how the restored saved model accepts inputs. But I haven't directly worked with this type of restored model before.
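
Putting the pieces of this thread together, the full hookup might then look roughly like this (an untested sketch; variable names such as trained_model_path, test_examples, and feats are the ones used above, and the "predict" signature name is assumed):

import tensorflow as tf
from witwidget.notebook.visualization import WitConfigBuilder, WitWidget

imported = tf.saved_model.load(export_dir=trained_model_path)

def predict(examples):
    # WIT passes a list of tf.Example protos; serialize them and run
    # the whole batch through the restored "predict" signature.
    return imported.signatures["predict"](
        examples=tf.constant([ex.SerializeToString() for ex in examples]))

config_builder = WitConfigBuilder(test_examples, feats + ['level']).set_custom_predict_fn(predict)
WitWidget(config_builder, height=1600)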