keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Can't set sprite in Keras Tensorboard Callback #15573

Closed: jvishnuvardhan closed this issue 1 year ago

jvishnuvardhan commented 3 years ago

This feature request was moved from the TensorFlow repository to keras-team/keras: https://github.com/tensorflow/tensorflow/issues/31050

Describe the feature and the current behavior/state.

There is no way to specify a sprite for embedding visualizations with the TensorBoard callback in Keras.
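For reference, a sprite is a single image that tiles one thumbnail per embedding point, in the same order as the embedding rows, so the projector can draw each point as its picture. Below is a minimal NumPy sketch of building one from N equally sized images, assuming the square, row-by-row grid layout the projector expects. make_sprite is a hypothetical helper, not part of Keras or TensorBoard, and saving the result as a PNG (e.g. with Pillow) is left out.

```python
import math

import numpy as np


def make_sprite(images: np.ndarray) -> np.ndarray:
    """Tile N equally sized images into one square sprite sheet.

    images has shape (N, H, W) or (N, H, W, C). Thumbnails are laid out
    row by row, so thumbnail i corresponds to embedding row i.
    """
    n, h, w = images.shape[:3]
    grid = math.ceil(math.sqrt(n))  # thumbnails per row and per column
    # Pad with blank thumbnails so the grid is completely filled.
    pad = grid * grid - n
    padding = np.zeros((pad,) + images.shape[1:], dtype=images.dtype)
    images = np.concatenate([images, padding], axis=0)
    # (grid*grid, H, W, ...) -> (grid, grid, H, W, ...)
    images = images.reshape((grid, grid) + images.shape[1:])
    # -> (grid, H, grid, W, ...) so each grid row becomes a band of pixel rows.
    axes = (0, 2, 1, 3) + tuple(range(4, images.ndim))
    images = images.transpose(axes)
    return images.reshape((grid * h, grid * w) + images.shape[4:])


# Usage: 10 random 28 x 28 grayscale thumbnails fill a 4 x 4 grid.
thumbs = np.random.randint(0, 256, size=(10, 28, 28), dtype=np.uint8)
sprite = make_sprite(thumbs)
print(sprite.shape)  # (112, 112)
```

Here the thumbnail size (28 x 28) is the value the proposed embeddings_image_size parameter would carry, and the saved sprite file is what embeddings_image_path would point to.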

Will this change the current api? How?

It would add two parameters to TensorBoard.__init__(): embeddings_image_path and embeddings_image_size.
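Following the comment in the patch, each of these parameters (like the existing embeddings_metadata) can be either one value shared by every embedding layer or a dict keyed by layer name. Here is a minimal sketch of that resolution rule in isolation, assuming None means "not set". per_layer is a hypothetical helper, not a Keras function.

```python
from typing import Any, Dict, Iterable


def per_layer(param: Any, layer_names: Iterable[str]) -> Dict[str, Any]:
    """Resolve an embeddings_* argument into a {layer_name: value} dict.

    None      -> {} (no value for any layer)
    dict      -> per-layer values, used as given
    otherwise -> the same value for every layer (e.g. one sprite path,
                 or one thumbnail size such as [28, 28])
    """
    if param is None:
        return {}
    if isinstance(param, dict):
        return dict(param)
    return {name: param for name in layer_names}


# Usage with two hypothetical embedding layer names.
layers = ["embedding_a", "embedding_b"]
print(per_layer("sprite.png", layers))
print(per_layer({"embedding_a": [28, 28]}, layers))
print(per_layer(None, layers))
```

Mapping None to an empty dict, rather than to {name: None}, keeps a missing value from being written into the projector config for every layer.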

Who will benefit from this feature?

Anyone who wants to see what the items in their embedding clusters actually look like. NLP users are already covered, since they can label points with words through embeddings_metadata.

Any other info.

Here is roughly what is needed; I just don't have time to write the unit tests and docs.


class TensorBoard(Callback):

    def __init__(self,
                   ...
                   embeddings_freq=0,
                   embeddings_layer_names=None,
                   embeddings_metadata=None,
                   embeddings_data=None,

                   ### add
                   embeddings_image_path=None,
                   embeddings_image_size=None,
                   ...):

        ### add 
        self.embeddings_image_path = embeddings_image_path
        self.embeddings_image_size = embeddings_image_size

    def set_model(self, model):
        ...

        ### add/rework a tad ~L282-L290
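        # Allow each embedding parameter to be passed as a dict
        # (layer_name -> value) or as a single value that applies to all
        # layers (a path string, or an image size such as [28, 28]).
        # None is left as None, which the checks below treat as "not set".
        layer_names = embeddings_vars.keys()

        embeddings_metadata = (
            {name: self.embeddings_metadata for name in layer_names}
            if self.embeddings_metadata is not None
            and not isinstance(self.embeddings_metadata, dict)
            else self.embeddings_metadata)

        embeddings_image_path = (
            {name: self.embeddings_image_path for name in layer_names}
            if self.embeddings_image_path is not None
            and not isinstance(self.embeddings_image_path, dict)
            else self.embeddings_image_path)

        embeddings_image_size = (
            {name: self.embeddings_image_size for name in layer_names}
            if self.embeddings_image_size is not None
            and not isinstance(self.embeddings_image_size, dict)
            else self.embeddings_image_size)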
        ###

        try:
            from tensorboard.plugins import projector
        except ImportError:
            raise ImportError('Failed to import TensorBoard. Please make sure that '
                              'TensorBoard integration is complete.')

        # TODO(psv): Add integration tests to test embedding visualization
        # with TensorBoard callback. We are unable to write a unit test for this
        # because TensorBoard dependency assumes TensorFlow package is installed.
        config = projector.ProjectorConfig()
        for layer_name, tensor in embeddings_vars.items():
            embedding = config.embeddings.add()
            embedding.tensor_name = tensor.name

            if (embeddings_metadata is not None and
                layer_name in embeddings_metadata):
                embedding.metadata_path = embeddings_metadata[layer_name]

            ### add 
            if (embeddings_image_path is not None and
                embeddings_image_size is not None and
                layer_name in embeddings_image_path and
                layer_name in embeddings_image_size):

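                # This is what I need: attach the sprite sheet and the size of
                # a single thumbnail, looked up for this layer only.
                embedding.sprite.image_path = embeddings_image_path[layer_name]
                embedding.sprite.single_image_dim.extend(
                    embeddings_image_size[layer_name])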
            ###

        projector.visualize_embeddings(self.writer, config)

qlzh727 commented 3 years ago

Assigning to @rchao, who works on callbacks.

SuryanarayanaY commented 1 year ago

Hello, thank you for reporting an issue.

We're currently in the process of migrating the new Keras 3 code base from keras-team/keras-core to keras-team/keras. Consequently, this issue may not be relevant to the Keras 3 code base. After the migration is successfully completed, feel free to reopen this issue at keras-team/keras if you believe it remains relevant to the Keras 3 code base. If instead this issue is a bug or security issue in legacy tf.keras, you can report a new issue at keras-team/tf-keras, which hosts the TensorFlow-only, legacy version of Keras.

To know more about Keras 3, please read https://keras.io/keras_core/announcement/