christiaanjs / treeflow

GNU General Public License v3.0

Can we serialize models? #15

Open: matsen opened this issue 4 years ago

matsen commented 4 years ago

I tried pickling a variational approximation by adding the following to the treeflow experiment in experiments/2020-06-18-getting-started:

diff --git a/experiments/2020-06-18-getting-started/run.py b/experiments/2020-06-18-getting-started/run.py
index 26507d8..9b2c717 100755
--- a/experiments/2020-06-18-getting-started/run.py
+++ b/experiments/2020-06-18-getting-started/run.py
@@ -4,6 +4,7 @@
 # # Tensorflow Probability & Treeflow Demo
 # Author: Christiaan Swanepoel

+import pickle
 import click
 import numpy as np
 import pandas as pd
@@ -116,6 +117,9 @@ with click.progressbar(range(TRIAL_COUNT), label="Trials") as bar:
             log_posterior, q_tmp, tf.optimizers.Adam(learning_rate=0.0001), 5
         )

+with open("test.pkl", "wb") as file:
+    pickle.dump(q_tmp, file)
+
 if USE_TF_PROFILER:
     tf.profiler.experimental.stop()

...and got the following error:

AttributeError: Can't pickle local object '_prob_chain_rule_model_flatten.<locals>._make.<locals>.<lambda>'

Poking around, it looks like this might not be too hard to fix.
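Some context on the error: pickle saves a plain function by reference. It records the function's module and qualified name, and unpickling looks that name up again. A lambda created inside another function (the `<locals>` parts of the error path) has no name that can be imported, so pickle rejects it. The stdlib-only sketch below reproduces this outside treeflow. `make_local_lambda`, `double` and `try_pickle` are hypothetical names that do not come from treeflow or TFP. Depending on the Python version, the failure is raised as `AttributeError` or `pickle.PicklingError`, so the sketch catches both.

```python
import pickle


def make_local_lambda():
    # Qualified name is 'make_local_lambda.<locals>.<lambda>', which pickle
    # cannot look up again at load time (same shape as the treeflow error).
    return lambda x: 2 * x


def double(x):
    # Module-level named function: pickle stores only module + qualified
    # name, not the function's code.
    return 2 * x


def try_pickle(obj):
    """Return True if obj survives a pickle round trip, False otherwise."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (AttributeError, pickle.PicklingError) as err:
        print(f"{type(err).__name__}: {err}")
        return False


if __name__ == "__main__":
    print(try_pickle(make_local_lambda()))  # fails: local lambda
    print(try_pickle(double))               # succeeds: named function
```

Pickle also stores the named function by reference only, so if `double` were renamed or moved, old pickles would stop loading.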

Or is there a better way to serialize such models?

christiaanjs commented 4 years ago

I think we'd have to use named functions instead of lambdas in the JointDistributionNamed constructor. This might be easier with a different JointDistribution concrete class; for example, you could do it with a single named function in JointDistributionCoroutine.
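In `JointDistributionNamed`, each entry of the model dict is either a distribution or a callable whose argument names refer to other entries. The sketch below imitates that wiring with plain Python. It uses hypothetical sampler functions (`sample_rate`, `sample_x`) instead of TFP distributions, and it only illustrates the pickling side of the suggestion. When every entry is a module-level function, or a `functools.partial` of one, the model spec round-trips through pickle. This does not show that TFP distribution objects themselves pickle; other internal state may still get in the way. By the same reasoning, a `JointDistributionCoroutine` model defined as one module-level generator function would be referenced by name, though that is an assumption here, not something tested.

```python
import functools
import inspect
import pickle
import random


def sample_rate(rng):
    # Root entry: depends on no other entry.
    return abs(rng.gauss(1.0, 0.1))


def sample_x(rate, rng, scale=1.0):
    # Depends on the upstream entry 'rate', matched by argument name.
    return rng.gauss(0.0, scale * rate)


# Hypothetical JointDistributionNamed-style spec: named functions (or partials
# of them) instead of lambdas, so pickle can store each entry by name.
MODEL = {
    "rate": sample_rate,
    "x": functools.partial(sample_x, scale=2.0),
}


def sample(model, seed=0):
    """Sample entries in dict order, passing upstream values by argument name."""
    rng = random.Random(seed)
    values = {}
    for name, fn in model.items():
        params = inspect.signature(fn).parameters
        upstream = {k: values[k] for k in params if k in values}
        values[name] = fn(rng=rng, **upstream)
    return values


if __name__ == "__main__":
    restored = pickle.loads(pickle.dumps(MODEL))
    print(sample(restored, seed=1) == sample(MODEL, seed=1))
```

With `lambda rng: ...` as the "rate" entry instead, `pickle.dumps(MODEL)` would fail in the same way as the error above.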

matsen commented 4 years ago

By the way, I also tried the dill serializer and got the error TypeError: can't pickle HashableWeakRef objects.

It seems the way to do this would be to build things up slowly...
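The dill failure points to a separate obstacle. Judging by its name, `HashableWeakRef` wraps a Python weak reference, which is an assumption because the issue does not show its source. Weak references cannot be pickled by design, since they would be meaningless once their referent is gone. The stdlib sketch below reproduces the underlying error with a hypothetical `Node` class. The exception type can vary across Python versions, so both `TypeError` and `pickle.PicklingError` are caught.

```python
import pickle
import weakref


class Node:
    """Hypothetical object that keeps a weak reference to its parent."""

    def __init__(self, parent=None):
        self.parent_ref = weakref.ref(parent) if parent is not None else None


def picklable(obj):
    """Return True if pickle.dumps succeeds, False if it raises."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError) as err:
        print(f"{type(err).__name__}: {err}")
        return False


if __name__ == "__main__":
    root = Node()
    child = Node(parent=root)
    print(picklable(root))   # no weak reference: pickles fine
    print(picklable(child))  # holds a weakref.ref: fails
```

So switching serializers alone probably won't help. Any object graph that reaches such a cache has to exclude it from the serialized state.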

matsen commented 4 years ago

This is from the TFP mailing list:

Saving models: This was probably the biggest headache for us. At the
commencement of the contest we had to send a saved model to the organizers.
We thought it would be as easy as inheriting tf.Module and using
tensorflow's saved model capabilities. We were wrong. We relied a lot on
using JointDistributionNamed objects and these objects completely refused
to be saved. In order to get our model object to save at all, I had to
exclude all JointDistributionNamed objects using NoDependency which in
essence meant that our model wasn't saved. We started developing the
project before JointDistributionCoroutine was included in TFP - no idea if
that has the same issue. Additionally, other non-tensorflow attributes such
as lists and dictionaries with model attributes didn't get saved when using
saved model. Eventually, we developed our own serialization using pickle.

cc @miparedes
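In the spirit of the custom pickle serialization described in the quote, one option is to pickle only the numeric parameters and rebuild everything else from code on load. The stdlib sketch below does this with `__reduce__` on a hypothetical `MeanFieldApprox` class of independent Normal factors. The class keeps per-variable local lambdas, built by a factory that mirrors the `_make.<locals>.<lambda>` shape of the original error, so default pickling would fail. None of these names come from treeflow. In a real TensorFlow model the parameters would presumably come from the trainable variables (e.g. their `.numpy()` values), which this sketch does not show.

```python
import math
import pickle


def _make_log_prob(loc, log_scale):
    # Returns a local lambda (unpicklable by default), like the closures
    # a JointDistributionNamed holds internally.
    scale = math.exp(log_scale)
    return lambda x: (-0.5 * ((x - loc) / scale) ** 2
                      - log_scale - 0.5 * math.log(2 * math.pi))


class MeanFieldApprox:
    """Hypothetical approximation: independent Normal(loc, exp(log_scale))."""

    def __init__(self, params):
        # params: dict mapping variable name -> (loc, log_scale) as floats.
        self.params = dict(params)
        self._log_prob_fns = {
            name: _make_log_prob(loc, log_scale)
            for name, (loc, log_scale) in self.params.items()
        }

    def log_prob(self, name, x):
        return self._log_prob_fns[name](x)

    def __reduce__(self):
        # Serialize only (class, constructor args); __init__ rebuilds the
        # lambdas when the object is unpickled.
        return (MeanFieldApprox, (self.params,))


if __name__ == "__main__":
    q = MeanFieldApprox({"rate": (0.5, -1.0), "height": (2.0, 0.0)})
    q2 = pickle.loads(pickle.dumps(q))
    print(q2.params == q.params, q2.log_prob("height", 1.5))
```

The trade-off is that the model structure lives in code rather than in the pickle. Old pickles then stay loadable only as long as the constructor keeps accepting the same parameters.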