t-kalinowski opened this issue 5 years ago
Cross posted this to stack overflow: https://stackoverflow.com/questions/55227732/how-to-use-a-pre-trained-keras-model-for-inference-in-tf-data-dataset-map
Hi Tomasz, sorry this got overlooked (not sure I can help though)!
Pragmatically speaking, couldn't you just build a "custom head" model ("as usual", so to say), and fit it with Keras `compile` & `fit` the classical way?
(In older TF versions, I was using eager mode & custom models with `__call__` syntax and tfdatasets because you had to, but now that you can do eager with `compile` & `fit`, perhaps that is an option again?)
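For illustration, a minimal sketch of that custom-head idea (all names here — `model_A`, the head layer, the input shape — are placeholders, not from the original thread):

```r
library(keras)

# Train only the new head; keep the pretrained base fixed.
freeze_weights(model_A)

# input shape is illustrative -- use whatever model_A expects
input  <- layer_input(shape = c(224, 224, 3))
output <- input %>%
  model_A() %>%
  layer_dense(units = 10, activation = "softmax")

model_B <- keras_model(input, output)
model_B %>% compile(optimizer = "adam", loss = "categorical_crossentropy")
# model_B %>% fit(dataset, epochs = 5)
```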
Otherwise, could you zoom in on the most "likely to work" (or "needed to work") version above, so one could concentrate on that?
Hi Sigrid, thanks for taking a look.
First, a little of my motivation for context (in case you see an easier way to achieve my goal). I've trained a model, let's call "Model A". I am happy with Model A, and now I want to train another model, Model B. Model B will take as input the output of Model A. I am not interested in training (at this time) the models A & B end-to-end, I want to focus on just Model B.
Model A was trained using a (relatively substantial) `tf.data` pipeline, and since inference with Model A is a stateless transformation, I was hoping to initially do something like

```r
dataset %>%
  dataset_map(function(x) {
    model_A$predict_on_batch(x)
  })
```
But of course that doesn't work, because `predict` is built on top of a `sess.run()` call, which isn't going to work inside a `dataset_map()`. So the next attempt was to use it like any other layer-reuse operation in Keras:

```r
dataset %>%
  dataset_map(function(x) {
    model_A(x)
  })
```
But that also doesn't work, because then there are a bunch of stateful operations in the batch-normalization layers, which tf.data doesn't like. So I'm kind of at an impasse. I guess the best way is to write out the graph containing the model, perhaps as a TensorFlow SavedModel, and then remove all stateful layers using a function like this:
```r
freeze_graph <- reticulate::py_run_string('
from tensorflow.python.framework import graph_util as tf_graph_util

def freeze_graph(session, outputs):
    """Freeze the current graph.

    Args:
        session: TensorFlow session containing the graph
        outputs: List of output tensors

    Returns:
        The frozen graph_def.
    """
    return tf_graph_util.convert_variables_to_constants(
        session, session.graph.as_graph_def(), [x.op.name for x in outputs])
', local = TRUE)$freeze_graph

g2 <- freeze_graph(k_get_session(), model_A$outputs)
```
but then I'm not sure how to actually use the graph `g2` for inference from inside `tf.data.Dataset.map()`.
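One untested possibility, sketched here on the assumption that `g2` is a frozen `graph_def`: splice it into the dataset's graph with `tf$import_graph_def()`, mapping the model's input placeholder to the dataset element. The tensor names (`"input_1:0"`, `"output_1:0"`) are illustrative and would need to match the actual graph:

```r
# Sketch only: import the frozen graph_def into the graph that
# dataset_map() builds, feeding the dataset element as the model input.
dataset %>%
  dataset_map(function(x) {
    out <- tf$import_graph_def(
      g2,
      input_map = dict("input_1:0" = x),
      return_elements = list("output_1:0")
    )
    out[[1]]
  })
```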
In any case, I'm kind of at an impasse, and any help would be appreciated. In the short term I'm thinking of just moving on from this approach: running Model A inference over a bunch of cases, saving the outputs offline (maybe as tfrecords), and then training Model B off those.
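That fallback might look something like the following sketch (caching as a plain `.rds` file for brevity; a TFRecord writer would follow the same materialize-then-train pattern, and `x_train`/`y_train` are illustrative names):

```r
# Materialize Model A's outputs once, offline -- inference is stateless,
# so the cached features are reusable across Model B experiments.
features <- model_A %>% predict(x_train)
saveRDS(features, "model_A_features.rds")

# Later, possibly in another session: train Model B on the cached features.
features <- readRDS("model_A_features.rds")
model_B %>% fit(features, y_train, epochs = 5)
```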
Hi Tomasz,
thanks for explaining! I'm afraid I don't really have a good idea here... You probably tried eager mode, to see if it helps with any of the above - or perhaps there are reasons why eager is not an option?
Hi Sigrid. I have not tried eager mode, and if you think it'll solve my problem I'm eager to try it.
How would I use a dataset in eager mode? Are you suggesting writing the training loop explicitly and using `model_B$train_on_batch()`? Or is there some more elegant way to still use `fit()` in eager mode?
Hm, this is a bit weird, because things that work/don't work with eager have been changing quite a bit since I started trying it with 1.10, so I may get confused ;-) ... but the code for this post
https://blogs.rstudio.com/tensorflow/posts/2019-02-07-audio-background/
uses tfdatasets and was explicitly tested to run in eager as well as static mode, against what was to become 1.13 (just master at that time)...
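For reference, the rough shape of an explicit eager training loop in that era (a sketch only, not code from the post; the batch field names are illustrative):

```r
library(tensorflow)
library(tfdatasets)

tfe_enable_eager_execution()

for (epoch in 1:n_epochs) {
  iter <- make_iterator_one_shot(dataset)
  until_out_of_range({
    batch <- iterator_get_next(iter)
    # In eager mode the model is just a callable -- no sess.run() involved,
    # so Model A can run as an ordinary transformation on the batch.
    x <- model_A(batch$x)
    model_B$train_on_batch(x, batch$y)
  })
}
```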
I've spent the past 2 days trying to figure out how to use a pretrained model as a transformation op inside of `dataset_map`, and in that time I've seen about 20 different error messages. Below is a minimal example that illustrates a few of my attempts. Any help would be much appreciated.
Some other things I've tried that aren't in the example include: loading the model inside of `dataset_map()`, either with or without specifying the input tensor as the tensor passed to the model; reaching in and trying to return `model$outputs`; using `tf$keras$models$clone_model()` a few different ways; and loading the model inside its own graph (using `with(tf$Graph()$as_default(), ...)`).
Some of the other error messages I've seen include things like:
Is this possible?