deepgram / kur

Descriptive Deep Learning
Apache License 2.0

plot_weights hook: plot unlimited weights; access weights directly from Executor.model, not external files #68

Closed EmbraceLife closed 7 years ago

EmbraceLife commented 7 years ago

I have now managed to plot the convolutional weights with the following kurfile:

hooks:
    - plot_weights:
        weight_file: cifar.best.valid.w
        weight_keywords1: ["convolution.0", "kernel"]
        weight_keywords2: ["convolution.1", "kernel"]

In my plot_weights_hook.py, I pass weight_keywords1 and weight_keywords2 into the hook through __init__():

def __init__(self, weight_file, weight_keywords1, weight_keywords2, *args, **kwargs):
    """ Creates a new plotting hook and gets the plot filenames and matplotlib ready.
    """

My question:

If I want to plot more convolutional weights, say weight_keywords3, weight_keywords4, weight_keywords5, do I have to change the source code by adding them to __init__ as above?

Can **kwargs somehow help me avoid changing the source every time I want to plot more weights? If so, how?

Thanks!

ajsyp commented 7 years ago

You can do it in two ways. First, yes, you could use **kwargs, but then you need to be careful to pass the correct pieces of kwargs through to the base class constructor (there are lots of ways you could arrange this, but all of them are fragile or hard to maintain). A better way would be to simply add another layer of indirection:

hooks:
  - plot_weights:
      weight_file: cifar.best.valid.w
      with_weights:
        - ['convolution.0', 'kernel']
        - ['convolution.1', 'kernel']
        - ['convolution.2', 'kernel']
        - ...

And then your constructor signature looks like this: def __init__(self, weight_file, with_weights, *args, **kwargs)
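A minimal sketch of how such a constructor might store the list (the super() call and attribute names here are assumptions for illustration, not the actual hook source):

# Sketch: a constructor that accepts an arbitrary list of keyword pairs
# instead of numbered keyword arguments.
def __init__(self, weight_file, with_weights, *args, **kwargs):
    """ Creates a new plotting hook; `with_weights` is a list of
        [layer_keyword, tensor_keyword] pairs of any length.
    """
    super().__init__(*args, **kwargs)
    self.weight_file = weight_file
    # Each entry selects one set of weights to plot, so plotting more
    # weights is a kurfile change rather than a source-code change.
    self.with_weights = [tuple(pair) for pair in with_weights]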

ajsyp commented 7 years ago

Also, it seems unfortunate that you need the weight_file. What if I were to pass the model instance into each TrainingHook and then you could just grab the latest model weights right there? Maybe still support weight_file if the user wants to manually choose a weight file. This has the benefit of always getting the newest model and not having to repeat oneself by specifying the weight file, but it will probably require a little thought about handling different model backends.

EmbraceLife commented 7 years ago

> What if I were to pass the model instance into each TrainingHook and then you could just grab the latest model weights right there?

I don't know how to do that at the moment; could you write a few lines of code to give me a hint?

> This has the benefit of always getting the newest model and not having to repeat oneself by specifying the weight file, but it will probably require a little thought about handling different model backends

I agree that keeping weight_file gives users more choice over which weight file to use. But I don't understand the part below, because having the user manually choose a file conflicts with that, right?

> not having to repeat oneself by specifying the weight file, but it will probably require a little thought about handling different model backends

ajsyp commented 7 years ago

I was suggesting having the option of leaving weight_file empty/absent. That way, if users specify a weight_file, then the hook uses those weights, but if they don't specify a weight_file, the hook just uses the current model weights at the time the hook was invoked.

Of course, this would require the model to be made accessible to the training hook (evaluation hooks are already given the model). We'd have to tweak the API slightly to support it, from this:

TrainingHook.notify(self, status, log=None, info=None)

to this:

TrainingHook.notify(self, status, log=None, info=None, model=None)
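A hook could then fall back to the live model whenever no weight_file is given; a minimal sketch (the helpers _load_weights, _weights_from_model, and _plot are hypothetical placeholders):

# Sketch of a TrainingHook using the new `model` parameter.
def notify(self, status, log=None, info=None, model=None):
    if self.weight_file is not None:
        weights = self._load_weights(self.weight_file)   # user-chosen file
    elif model is not None:
        weights = self._weights_from_model(model)        # live model weights
    else:
        return                                           # nothing to plot
    self._plot(weights)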

EmbraceLife commented 7 years ago

To access the latest model weights at the end of an epoch of training:

I shall insert the following code into plot_weights_hook.py, so that when weight_file is empty, idx.load is not used; instead, the following code gets the weight arrays:

# self refers to the Executor (trainer) object
layers = self.model.flattened_layers \
    if hasattr(self.model, 'flattened_layers') else self.model.layers
for layer in layers:
    layer_name = layer.name

    symbolic_weights = layer.weights
    weight_names, weight_values = \
        self._get_weight_names_and_values_from_symbolic(
            symbolic_weights
        )

    # use weight_values (arrays) to plot the weights

ajsyp commented 7 years ago

You could do something like that, but it would only work with a Keras model, not a PyTorch one. A better idea would be to create a temporary file (use Python's tempfile module) and save the weights there using the public API (Model.save()). You then have them in a standardized form on disk.
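For instance, a backend-agnostic hook body might look like this (a sketch: model.save() is the public API named above, while _plot_from_saved_weights is a hypothetical helper):

import os
import shutil
import tempfile

def notify(self, status, log=None, info=None, model=None):
    # Save the current weights through the public API so the on-disk
    # format is the same regardless of the backend in use.
    tempdir = tempfile.mkdtemp()
    try:
        weight_path = os.path.join(tempdir, 'weights')
        model.save(weight_path)
        self._plot_from_saved_weights(weight_path)  # hypothetical helper
    finally:
        shutil.rmtree(tempdir)  # always remove the temporary directory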

EmbraceLife commented 7 years ago

Sounds good, but what is this public API (Model.save())? Where can I find it?

All I can find is kur.model.Model.save, but this eventually leads to keras_backend._save_keras.

ajsyp commented 7 years ago

That's it: kur.model.Model.save(). But it only calls keras_backend._save_keras() if the Keras backend is selected; if the PyTorch backend is in use, it will call pytorch_backend.save() instead. That's the reason for using inheritance (base/derived classes), and the whole point of designing an API.
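Schematically, the dispatch looks like this (a simplified sketch, not the actual kur source):

class Backend:
    def save(self, model, filename):
        raise NotImplementedError

class KerasBackend(Backend):
    def save(self, model, filename):
        ...  # eventually calls _save_keras()

class PyTorchBackend(Backend):
    def save(self, model, filename):
        ...  # PyTorch-specific serialization

class Model:
    def save(self, filename):
        # Delegates to whichever backend is in use, so callers
        # never need to know which backend that is.
        self.backend.save(self, filename)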

EmbraceLife commented 7 years ago

I will continue to work on this tomorrow. Thanks a lot for your help!

EmbraceLife commented 7 years ago

The workflow I am trying: here is how to create a temporary folder inside Executor.wrapped_train, before executing plot_weights_hook at the end of an epoch:

import os
import tempfile

weight_path = None
tempdir = tempfile.mkdtemp()  # temporary folder to hold the saved weights
try:
    weight_path = os.path.join(tempdir, 'weights')

Here is how to save the weights into weight_path under the tempdir:

self.model.save(weight_path)  # self.model is the kur Model instance on the Executor

Here is how to access the individual weight files under weight_path:

# borrowed from keras_backend.enumerate_saved_tensors;
# `regex` is that function's saved-tensor filename pattern
for dirpath, dirnames, filenames in os.walk(weight_path):  # pylint: disable=unused-variable
    for filename in filenames:
        match = regex.match(filename)
        if match is None:
            continue
        filename = os.path.join(dirpath, filename)

Then this connects to plot_weights_hook.py, which extracts the weights from the individual weight files and plots them.
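Putting those steps together, the whole flow might look like the sketch below (assumptions: kur.utils.idx is the idx loader mentioned above, the filename matching against the kurfile keyword pairs is illustrative, and plot_one is a hypothetical plotting helper):

import os
import shutil
import tempfile

from kur.utils import idx  # assumed location of kur's idx tensor loader


def plot_weights_from_model(model, with_weights):
    """ Sketch: save the live model's weights through the public API, then
        load and plot every saved tensor whose filename matches one of the
        keyword pairs from the kurfile. """
    tempdir = tempfile.mkdtemp()
    try:
        weight_path = os.path.join(tempdir, 'weights')
        model.save(weight_path)  # backend-agnostic save
        for dirpath, _, filenames in os.walk(weight_path):
            for filename in filenames:
                if any(all(word in filename for word in keywords)
                       for keywords in with_weights):
                    weights = idx.load(os.path.join(dirpath, filename))
                    plot_one(weights, filename)  # hypothetical plotting helper
    finally:
        shutil.rmtree(tempdir)  # clean up the temporary directory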

EmbraceLife commented 7 years ago

I have updated the PR; all the features discussed above seem to be working so far.