You can assemble such a graph with TFX by defining multiple Trainer components, as shown in the following code example:
```python
# Helper function to instantiate a Trainer for each graph branch
def set_trainer(module_file, instance_name, train_steps=5000, eval_steps=100):
    return Trainer(
        module_file=module_file,
        custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
        examples=transform.outputs['transformed_examples'],
        transform_graph=transform.outputs['transform_graph'],
        schema=schema.outputs['schema'],
        train_args=trainer_pb2.TrainArgs(num_steps=train_steps),
        eval_args=trainer_pb2.EvalArgs(num_steps=eval_steps),
        instance_name=instance_name)

# Load a separate module file for each Trainer
prod_module_file = os.path.join(pipeline_dir, 'prod_module.py')
trial_module_file = os.path.join(pipeline_dir, 'trial_module.py')
...

# Instantiate a Trainer component for each graph branch
trainer_prod_model = set_trainer(prod_module_file, 'production_model')
trainer_trial_model = set_trainer(trial_module_file, 'trial_model',
                                  train_steps=10000, eval_steps=500)
```
Each instantiated Trainer component needs to be consumed by its own Evaluator, as shown in the following code example. Afterward, each model can be pushed by its own Pusher component:
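```python
# Each branch is evaluated against its own EvalConfig;
# eval_config_prod_model and eval_config_trial_model are assumed to be
# defined elsewhere in the pipeline.
evaluator_prod_model = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer_prod_model.outputs['model'],
    eval_config=eval_config_prod_model,
    instance_name='production_model')

evaluator_trial_model = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer_trial_model.outputs['model'],
    eval_config=eval_config_trial_model,
    instance_name='trial_model')
```

The Pusher components follow the same per-branch pattern. The following is a minimal sketch, assuming each branch pushes to its own serving directory (`prod_serving_dir` is a placeholder):

```python
from tfx.components import Pusher
from tfx.proto import pusher_pb2

# Push the production model to its serving directory once it is blessed
pusher_prod_model = Pusher(
    model=trainer_prod_model.outputs['model'],
    model_blessing=evaluator_prod_model.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=prod_serving_dir)),
    instance_name='production_model')
```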
- Converting Models to TFLite: The conversion happens inside the Trainer's `run_fn`. After the model is exported as a saved model, it is rewritten to the TFLite format, as shown in the following code example:
```python
import os

import tensorflow as tf

from tfx.components.trainer.executor import TrainerFnArgs
from tfx.components.trainer.rewriting import converters
from tfx.components.trainer.rewriting import rewriter
from tfx.components.trainer.rewriting import rewriter_factory


def run_fn(fn_args: TrainerFnArgs):
    # model and signatures are created in the omitted training code
    ...
    temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, "temp")

    # Export the model as a saved model into a temporary directory
    model.save(temp_saving_model_dir,
               save_format='tf',
               signatures=signatures)

    # Instantiate the TFLite rewriter
    tfrw = rewriter_factory.create_rewriter(
        rewriter_factory.TFLITE_REWRITER,
        name='tflite_rewriter',
        enable_experimental_new_converter=True
    )

    # Convert the model to TFLite format and write it to the
    # serving model directory
    converters.rewrite_saved_model(temp_saving_model_dir,
                                   fn_args.serving_model_dir,
                                   tfrw,
                                   rewriter.ModelType.TFLITE_MODEL)

    # Delete the temporary saved model after conversion
    tf.io.gfile.rmtree(temp_saving_model_dir)
```
Instead of exporting a saved model at the end of the training, we convert the saved model to a TFLite-compatible model and delete the saved model after the conversion. Our Trainer component then exports and registers the TFLite model with the metadata store. Downstream components like the Evaluator or the Pusher can then consume the TFLite-compliant model. The following example shows how we can evaluate the TFLite model, which is helpful in detecting whether the model optimizations have led to a degradation of the model's performance:
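```python
# eval_config is assumed to be a tfma.EvalConfig whose model spec
# declares the model type as TFLite, so that the model is evaluated
# with the TFLite interpreter instead of being loaded as a SavedModel.
evaluation = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer_mobile_model.outputs['model'],
    eval_config=eval_config,
    instance_name="tflite_model")
```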
- Converting Models to TensorFlow.js: Since TFX version 0.22, an additional feature of the `rewriter_factory` is available: the conversion of preexisting TensorFlow models to TensorFlow.js models. This conversion allows the deployment of models to web browsers and Node.js runtime environments. You can use this new functionality by replacing the rewriter name with `rewriter_factory.TFJS_REWRITER` and setting the `rewriter.ModelType` to `rewriter.ModelType.TFJS_MODEL`, as shown in the sketch below.
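The following is a minimal sketch of the rewriting step with the TensorFlow.js rewriter swapped in; `temp_saving_model_dir`, `fn_args`, and the rewriter name `tfjs_rewriter` are carried over or assumed from the TFLite example above:

```python
# Instantiate the TensorFlow.js rewriter
tfrw = rewriter_factory.create_rewriter(
    rewriter_factory.TFJS_REWRITER,
    name='tfjs_rewriter'
)

# Convert the saved model to the TensorFlow.js format
converters.rewrite_saved_model(temp_saving_model_dir,
                               fn_args.serving_model_dir,
                               tfrw,
                               rewriter.ModelType.TFJS_MODEL)
```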