databricks / spark-deep-learning

Deep Learning Pipelines for Apache Spark
https://databricks.github.io/spark-deep-learning
Apache License 2.0

How to pass custom objects to KerasImageFileTransformer. keras load_model supports custom_objects #162

Open vvivek921 opened 5 years ago

vvivek921 commented 5 years ago

Hi, I am using a custom dice_loss object to train my model. Is there any way to pass custom objects to load_model in Spark DL? Or does Spark DL not support loading models that contain custom objects?

When I load the model with Keras directly, I use:

    model = tf.keras.models.load_model(mask_model_file,
                                       custom_objects={'bce_dice_loss': bce_dice_loss, 'dice_loss': dice_loss})

as suggested in https://github.com/keras-team/keras/issues/3977.
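A self-contained sketch of that call follows. The thread does not include the real `dice_loss` and `bce_dice_loss`, so the definitions here are a common soft-Dice formulation assumed for illustration, and the tiny model and temporary path are placeholders. Because the model is compiled with `bce_dice_loss`, the saved file refers to the loss only by name, and `custom_objects` maps that name back to a Python function when `load_model` recompiles the model.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


# Hypothetical loss definitions; the originals are not shown in the issue.
def dice_loss(y_true, y_pred, smooth=1.0):
    y_true = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true * y_pred)
    dice = (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)
    return 1.0 - dice


def bce_dice_loss(y_true, y_pred):
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return bce + dice_loss(y_true, y_pred)


# A tiny stand-in for the mask model, compiled with the custom loss.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss=bce_dice_loss)

mask_model_file = os.path.join(tempfile.mkdtemp(), "mask_model.h5")
model.save(mask_model_file)

# The file stores only the name "bce_dice_loss"; custom_objects maps it back
# to the Python function so the model can be recompiled on load.
loaded = tf.keras.models.load_model(
    mask_model_file,
    custom_objects={"bce_dice_loss": bce_dice_loss, "dice_loss": dice_loss})

x = np.random.rand(3, 4).astype("float32")
print(np.allclose(model.predict(x, verbose=0), loaded.predict(x, verbose=0)))
```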

KerasImageFileTransformer does not seem to support loading custom objects. The code below fails:

    mask_transformer = KerasImageFileTransformer(inputCol='uri', outputCol='mask',
                                                 modelFile=mask_model_file,
                                                 imageLoader=load_preprocess_mask_img,
                                                 outputMode='vector')
    masks = mask_transformer.transform(uri_df)
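The `imageLoader` passed above, `load_preprocess_mask_img`, is not shown in the thread. As a rough sketch, assuming local file paths and a hypothetical 256x256 model input, such a loader takes one URI string and returns a single-image batch as a 4-D numpy array of shape (1, height, width, channels), which is the shape the image-loader example in the sparkdl README produces. The [0, 1] pixel scaling is likewise an assumption and has to match whatever preprocessing the model was trained with.

```python
import os
import tempfile

import numpy as np
from PIL import Image

# Hypothetical model input size (width, height); adjust to the real mask model.
TARGET_SIZE = (256, 256)


def load_preprocess_mask_img(uri):
    """Load one local image and return a (1, height, width, 3) float32 batch."""
    img = Image.open(uri).convert("RGB").resize(TARGET_SIZE)
    arr = np.asarray(img, dtype=np.float32) / 255.0  # scale pixels to [0, 1]
    return np.expand_dims(arr, axis=0)  # add the leading batch dimension of 1


# Quick check with a solid red test image written to a temporary file.
sample_path = os.path.join(tempfile.mkdtemp(), "sample.png")
Image.new("RGB", (64, 48), color=(255, 0, 0)).save(sample_path)
batch = load_preprocess_mask_img(sample_path)
print(batch.shape)  # (1, 256, 256, 3)
```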

The stack trace for the failure is:

TypeError                                 Traceback (most recent call last)
<ipython-input-51-8fe840872c2e> in <module>()
----> 1 masks = mask_transformer.transform(uri_df)

/opt/spark-2.3.2/python/pyspark/ml/base.py in transform(self, dataset, params)
    171                 return self.copy(params)._transform(dataset)
    172             else:
--> 173                 return self._transform(dataset)
    174         else:
    175             raise ValueError("Params must be a param map but got %s." % type(params))

/private/var/folders/b5/9rq_y2gx4sz5k92cgzmcfz95cn72xb/T/spark-857b86db-c3f7-4376-a2cf-7b6c8c40ac74/userFiles-8ee3f002-5bbf-44f0-8897-3bda2c93b6e7/databricks_spark-deep-learning-1.2.0-spark2.3-s_2.11.jar/sparkdl/transformers/keras_image.py in _transform(self, dataset)
     60         with KSessionWrap() as (sess, keras_graph):
     61             graph, inputTensorName, outputTensorName = self._loadTFGraph(sess=sess,
---> 62                                                                          graph=keras_graph)
     63             image_df = self.loadImagesInternal(dataset, self.getInputCol())
     64             transformer = TFImageTransformer(channelOrder='RGB', inputCol=self._loadedImageCol(),

/private/var/folders/b5/9rq_y2gx4sz5k92cgzmcfz95cn72xb/T/spark-857b86db-c3f7-4376-a2cf-7b6c8c40ac74/userFiles-8ee3f002-5bbf-44f0-8897-3bda2c93b6e7/databricks_spark-deep-learning-1.2.0-spark2.3-s_2.11.jar/sparkdl/param/shared_params.py in _loadTFGraph(self, sess, graph)
    169         with graph.as_default():
    170             K.set_learning_phase(0)  # Inference phase
--> 171             model = load_model(self.getModelFile())
    172             out_op_name = tfx.op_name(model.output, graph)
    173             stripped_graph = tfx.strip_and_freeze_until([out_op_name], graph, sess,

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/saving.py in load_model(filepath, custom_objects, compile)
    258             raise ValueError('No model found in config file.')
    259         model_config = json.loads(model_config.decode('utf-8'))
--> 260         model = model_from_config(model_config, custom_objects=custom_objects)
    261 
    262         # set weights

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/saving.py in model_from_config(config, custom_objects)
    332                         '`Sequential.from_config(config)`?')
    333     from ..layers import deserialize
--> 334     return deserialize(config, custom_objects=custom_objects)
    335 
    336 

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
     53                                     module_objects=globs,
     54                                     custom_objects=custom_objects,
---> 55                                     printable_module_name='layer')

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    143                     config['config'],
    144                     custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 145                                         list(custom_objects.items())))
    146             with CustomObjectScope(custom_objects):
    147                 return cls.from_config(config['config'])

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/network.py in from_config(cls, config, custom_objects)
   1025                 if layer in unprocessed_nodes:
   1026                     for node_data in unprocessed_nodes.pop(layer):
-> 1027                         process_node(layer, node_data)
   1028 
   1029         name = config.get('name')

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/network.py in process_node(layer, node_data)
    984             # and building the layer if needed.
    985             if input_tensors:
--> 986                 layer(unpack_singleton(input_tensors), **kwargs)
    987 
    988         def process_layer(layer_data):

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
    429                                          'You can build it manually via: '
    430                                          '`layer.build(batch_input_shape)`')
--> 431                 self.build(unpack_singleton(input_shapes))
    432                 self.built = True
    433 

/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/layers/normalization.py in build(self, input_shape)
     90 
     91     def build(self, input_shape):
---> 92         dim = input_shape[self.axis]
     93         if dim is None:
     94             raise ValueError('Axis ' + str(self.axis) + ' of '

TypeError: tuple indices must be integers or slices, not list
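Judging from the last frames, this particular failure may not involve the custom loss at all: `load_model` is still rebuilding layers (`model_from_config`, line 260 of `saving.py`), and `BatchNormalization.build` indexes the input shape with `self.axis`, which evidently holds a list rather than an int. The frames also show the model being deserialized by the standalone `keras` package, while the working call above uses `tf.keras`; a config difference between the two is one plausible cause, but the thread does not confirm it. The pure-Python lines below, with illustrative values, reproduce the same TypeError and show what an int axis returns:

```python
# Illustrative values: an NHWC shape and an axis deserialized as a list.
input_shape = (None, 128, 128, 16)
axis_from_config = [3]

# Same operation as `dim = input_shape[self.axis]` in the last frame.
try:
    dim = input_shape[axis_from_config]
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # tuple indices must be integers or slices, not list

# With a plain int axis (here: the list's single element), indexing succeeds.
axis = axis_from_config[0] if isinstance(axis_from_config, list) else axis_from_config
dim = input_shape[axis]
print(dim)  # 16
```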

MrBanhBao commented 5 years ago

@vvivek921 I currently have the same problem. Did you find a way to load custom objects?

ghunkins commented 5 years ago

The KerasImageFileTransformer is meant to be used with images. Try the KerasTransformer class instead.

Example:

    from sparkdl import KerasTransformer
    from keras.models import Sequential
    from keras.layers import Dense
    import numpy as np

    # Generate random input data
    # (`sqlContext` is assumed to exist already, e.g. in the pyspark shell or a Databricks notebook)
    num_features = 10
    num_examples = 100
    input_data = [{"features": np.random.randn(num_features).tolist()} for _ in range(num_examples)]
    input_df = sqlContext.createDataFrame(input_data)

    # Create and save a single-hidden-layer Keras model for binary classification
    # NOTE: In a typical workflow, we'd train the model before exporting it to disk,
    # but we skip that step here for brevity
    model = Sequential()
    model.add(Dense(units=20, input_shape=[num_features], activation='relu'))
    model.add(Dense(units=1, activation='sigmoid'))
    model_path = "/tmp/simple-binary-classification"
    model.save(model_path)

    # Create transformer and apply it to our input data
    transformer = KerasTransformer(inputCol="features", outputCol="predictions",
                                   modelFile=model_path)
    final_df = transformer.transform(input_df)

Source: https://github.com/databricks/spark-deep-learning#working-with-images-in-spark

MrBanhBao commented 5 years ago

@ghunkins Yep, I know. My problem is that MobileNet, for instance, uses custom objects such as ReLU(6.) and DepthwiseConv2D, which Keras can load inside a CustomObjectScope. However, sparkdl's KerasImageFileTransformer does not seem to have a parameter for loading those custom objects.
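For reference, a minimal `tf.keras` sketch of loading inside a `CustomObjectScope`. The activation `capped_relu6` is a hypothetical stand-in for MobileNet's ReLU(6.), and the model and file path are placeholders. Inside the `with` block, the activation name stored in the layer config resolves to the Python function registered in the scope.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def capped_relu6(x):
    """Hypothetical custom activation standing in for MobileNet's ReLU(6.)."""
    return tf.nn.relu6(x)  # min(max(x, 0), 6)


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation=capped_relu6),
    tf.keras.layers.Dense(1),
])
model_path = os.path.join(tempfile.mkdtemp(), "relu6_model.h5")
model.save(model_path)

# Within the scope, "capped_relu6" is resolved to the function above
# while the saved layer configs are deserialized.
with tf.keras.utils.CustomObjectScope({"capped_relu6": capped_relu6}):
    loaded = tf.keras.models.load_model(model_path)

x = np.random.rand(2, 4).astype("float32")
print(np.allclose(model.predict(x, verbose=0), loaded.predict(x, verbose=0)))
```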

vvivek921 commented 5 years ago

@MrBanhBao You could modify the load_model call in KerasImageFileTransformer to accept a custom-objects parameter, then rebuild Spark DL and use that build. I haven't tried this myself, though; I ended up not using Spark DL.
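An alternative that might avoid rebuilding, untested with sparkdl: the traceback (line 145 of `generic_utils.py`) shows Keras merging `_GLOBAL_CUSTOM_OBJECTS` with the `custom_objects` argument, so objects registered globally before calling `transform` could be found by sparkdl's plain `load_model(self.getModelFile())` call. The sketch below demonstrates only the mechanism, with `tf.keras` and a hypothetical `dice_loss`. Since the trace shows sparkdl 1.2.0 loading through standalone `keras`, the registration there would presumably need `keras.utils.get_custom_objects()` instead, done on the driver, where `_transform` appears to load the model. It only helps with name lookup and likely would not fix the `BatchNormalization` axis error in the trace.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def dice_loss(y_true, y_pred, smooth=1.0):
    """Hypothetical soft Dice loss; the issue does not show the real one."""
    intersection = tf.reduce_sum(y_true * y_pred)
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)


model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss=dice_loss)
model_path = os.path.join(tempfile.mkdtemp(), "dice_model.h5")
model.save(model_path)

# Register the loss once in Keras' global custom-object registry...
tf.keras.utils.get_custom_objects()["dice_loss"] = dice_loss

# ...so a plain load_model(path) call, with no custom_objects argument,
# can still resolve "dice_loss" when it recompiles the model.
loaded = tf.keras.models.load_model(model_path)

x = np.random.rand(2, 4).astype("float32")
print(np.allclose(model.predict(x, verbose=0), loaded.predict(x, verbose=0)))
```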