christophstoeckl / FS-neurons


Minimal working sample #12

Closed mervess closed 3 years ago

mervess commented 3 years ago

Hello @christophstoeckl , I tried the code with a few different toy network models (including `trivial_model` under `tf_models`), but none seem to work. I expected it to at least print out the number of spikes and neurons, as I changed those params to True in the code. I have a few questions regarding it:

christophstoeckl commented 3 years ago

Hi, it seems that trivial_model does not use an activation function, hence FS-conversion would not do anything. Here is a slightly modified version which works:

import tensorflow as tf
from fs_coding import *  # also provides the `layers` used below
from tensorflow.python.keras import backend
from tensorflow.python.keras import models

replace_relu_with_fs()  # overwrite ReLU with the FS activation

def trivial_model(num_classes):
  """Trivial model for ImageNet dataset."""

  input_shape = (224, 224, 3)
  img_input = layers.Input(shape=input_shape)

  x = layers.Lambda(lambda x: backend.reshape(x, [-1, 224 * 224 * 3]),
                    name='reshape')(img_input)
  x = layers.Dense(100, name='fc1')(x)
  x = layers.Activation('relu')(x)
  x = layers.Dense(num_classes, name='fc1000')(x)
  x = layers.Activation('softmax')(x)

  return models.Model(img_input, x, name='trivial')

model = trivial_model(10)

# run the model on some input
model(tf.ones(shape=(10, 224, 224, 3)))

With the code above, printing the number of neurons should work if the corresponding flag has been set to True in fs_coding.py. Printing the number of spikes only works with TF 1.14, as it uses the now-deprecated tf.Print function.

mervess commented 3 years ago

Thanks for the answers and the example.

The above example prints `Number of neurons: 100`, yet crashes in my workspace (TF-1.14.0, Python 3.7):

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]
         [[{{node input_1}}]]

and the source seems to be `x = layers.Dense(num_classes, name='fc1000')(x)`. Did it work fine for you?

When `from tensorflow.python.keras import layers` is imported, there is no error, but no info printing either. Overall, the `layers` provided by fs_coding.py should be used for the conversion to take effect, correct?


I think both should work, as long as you use layers.Activation('relu') or tf.nn.relu (in case you want to convert relu neurons).

I've also noticed that embedded activation functions, such as `activation=tf.nn.relu`, do not work; they need to be declared separately in the model.

Unfortunately, Sequential Keras models do not seem to work with the code; they throw an error at the last moment:

TypeError: The added layer must be an instance of class Layer. Found: <function fs_relu at 0x7fc37fc3cdd0>

The source of the above error is `layers.Activation('softmax')`.

Note that converting compiled models will probably not work.

After calling the relevant FS functions, the model should then be compiled/built or trained/tested. Is that how it works?

christophstoeckl commented 3 years ago

I have tested this example with TF-2.5. To get it to work with TF-1.14, I believe you would have to create an input placeholder and run the model inside a Session, as eager execution is not the default in TF before version 2.0.
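A minimal sketch of that graph-mode pattern, written against `tf.compat.v1` so it also runs under TF 2.x; fs_coding is not involved here, and the layer names are just illustrative:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # TF-1.x graph semantics

# Placeholder matching the shape from the error message above.
inp = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name="input_1")

x = tf.reshape(inp, [-1, 224 * 224 * 3])
x = tf.layers.dense(x, 100, name="fc1")
x = tf.nn.relu(x)                       # explicit activation, as advised
logits = tf.layers.dense(x, 10, name="fc1000")
probs = tf.nn.softmax(logits)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(probs,
                   feed_dict={inp: np.ones((2, 224, 224, 3), np.float32)})
    print(out.shape)  # (2, 10)
```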

Yes. The idea behind the replacement is simply to overwrite the TF/Keras activation functions with the corresponding FS function. This was meant to make it easier to convert an existing ANN without rewriting many lines of code. The activation function in the ANN must therefore be implemented explicitly, so that it can be replaced (hence embedded activation functions are problematic).
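The mechanism can be illustrated without TF at all. Here is a pure-Python sketch (all names are stand-ins, not the real fs_coding API) of why a module-level replacement reaches explicit activations but not embedded ones:

```python
import types

# Stand-in for a framework module that exposes an activation function.
framework = types.SimpleNamespace(relu=lambda x: max(0.0, x))

calls = []

def fs_relu(x):
    calls.append("fs")          # record that the FS version ran
    return max(0.0, x)          # (the real fs_relu would emit spikes)

def replace_relu_with_fs():
    framework.relu = fs_relu    # overwrite at module level

# An *explicit* activation looks the function up at call time,
# so it sees the replacement:
explicit_layer = lambda x: framework.relu(x)

# An *embedded* activation captured the original function object when the
# model was built, so the replacement cannot reach it:
embedded_layer = (lambda act: lambda x: act(x))(framework.relu)

replace_relu_with_fs()

explicit_layer(1.0)   # runs fs_relu
embedded_layer(1.0)   # still runs the original relu
print(calls)          # ['fs']
```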

That's right; sequential Keras models will then not work out of the box.
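To see why Sequential balks, here is a pure-Python sketch of the type check involved; `Layer` and `Sequential` are simplified stand-ins for the Keras classes, not the real implementation:

```python
class Layer:
    pass

class Sequential:
    def __init__(self):
        self.layers = []

    def add(self, layer):
        # Keras' Sequential.add type-checks for Layer instances, which is
        # why a bare function produced by the FS replacement is rejected.
        if not isinstance(layer, Layer):
            raise TypeError(
                "The added layer must be an instance of class Layer. "
                f"Found: {layer!r}")
        self.layers.append(layer)

def fs_relu(x):            # a bare function, not a Layer
    return max(0.0, x)

model = Sequential()
try:
    model.add(fs_relu)     # mirrors the TypeError from the thread
except TypeError as e:
    print(e)
```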

The workflow usually goes like this:

  1. First you train an ANN without using FS-coding.
  2. You save the weights of the trained ANN.
  3. You then activate FS-coding and restore the previously obtained weights.

It is probably not a good idea to train a model directly when it uses FS-coding.
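The three steps above can be sketched schematically in pure Python; `build_model`, the weight save/load helpers, and `replace_relu_with_fs` are all stand-ins for the real Keras and fs_coding calls:

```python
import json
import os
import tempfile

ACTIVATION = {"relu": "standard"}      # stand-in for the module-level ReLU

def replace_relu_with_fs():
    ACTIVATION["relu"] = "fs"          # stand-in for fs_coding's switch

def build_model():
    # Stand-in for a model constructor; a dict of layer-name -> weights.
    return {"fc1": [0.0, 0.0, 0.0, 0.0]}

def save_weights(model, path):
    with open(path, "w") as f:
        json.dump(model, f)

def load_weights(path):
    with open(path) as f:
        return json.load(f)

# 1. Train an ANN without FS-coding (training elided; pretend weights).
ann = build_model()
ann["fc1"] = [0.1, 0.2, 0.3, 0.4]

# 2. Save the weights of the trained ANN.
path = os.path.join(tempfile.mkdtemp(), "ann.json")
save_weights(ann, path)

# 3. Activate FS-coding, rebuild the model, restore the weights.
replace_relu_with_fs()
fs_model = build_model()
fs_model.update(load_weights(path))

print(ACTIVATION["relu"], fs_model["fc1"])  # fs [0.1, 0.2, 0.3, 0.4]
```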