google / objax

Apache License 2.0

Objax2Tf enhancements #179

Open AlexeyKurakin opened 3 years ago


This issue tracks progress on addressing some current limitations of the Objax2Tf converter:

1. Shape polymorphism.

Right now, saving an Objax2Tf model as a TensorFlow SavedModel requires specifying the input shape, including a concrete value for the batch dimension. Shape polymorphism would let us use `None` for the batch dimension, so the generated SavedModel could be used with any batch size.

Shape polymorphism was recently added to jax2tf and should be working there; however, I didn't manage to make it work with Objax2Tf. We need to investigate why. We may also need to wait until shape polymorphism support improves on the JAX side.

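2. Updating batch norm parameters (and similar state) in the generated TensorFlow model. For example, the call below should update the batch norm statistics stored in the TensorFlow model: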

```python
import numpy as np
import objax

m = objax.nn.BatchNorm0D(3)
tfm = objax.util.Objax2Tf(m)

x = np.random.normal(size=(10, 3)).astype(np.float32)
# This should update the batch norm variables of the generated TensorFlow model
y = tfm(x, training=True)
```
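One possible design, sketched here under stated assumptions, is to convert a pure function that returns the updated state alongside the output and let the TensorFlow-side wrapper write that state back into its variables. The sketch uses NumPy only; `batchnorm_step` and `StatefulWrapper` are hypothetical names, not part of Objax or Objax2Tf, and the momentum value is illustrative. In a real implementation the pure function would be the one passed to jax2tf, and the write-back would be `tf.Variable.assign`.

```python
import numpy as np


def batchnorm_step(state, x, training, momentum=0.9, eps=1e-6):
    """Pure batch norm: returns (output, new_state) and mutates nothing.

    `state` holds running 'mean' and 'var' arrays of shape (C,).
    """
    if training:
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        new_state = {
            "mean": momentum * state["mean"] + (1 - momentum) * mean,
            "var": momentum * state["var"] + (1 - momentum) * var,
        }
    else:
        # Inference uses the running statistics and leaves them unchanged.
        mean, var = state["mean"], state["var"]
        new_state = state
    y = (x - mean) / np.sqrt(var + eps)
    return y, new_state


class StatefulWrapper:
    """Hypothetical stand-in for the TF side: owns the variables and
    writes the returned state back after every call."""

    def __init__(self, fn, state):
        self.fn = fn
        self.state = {k: v.copy() for k, v in state.items()}

    def __call__(self, x, training=False):
        y, new_state = self.fn(self.state, x, training)
        # Write-back step; in TensorFlow this would be variable.assign(value).
        for k, v in new_state.items():
            self.state[k] = v
        return y


wrapped = StatefulWrapper(batchnorm_step, {"mean": np.zeros(3), "var": np.ones(3)})
x = np.random.default_rng(0).normal(loc=5.0, size=(10, 3))
y = wrapped(x, training=True)
print(wrapped.state["mean"])  # moved from 0 toward the batch mean
```

Keeping the converted function pure matches how jax2tf traces functions, and it leaves the decision of when to persist state (training vs. inference) to the wrapper.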

3. Investigate whether a generated Objax2Tf model can be trained or fine-tuned in TensorFlow.