Bihaqo / t3f

Tensor Train decomposition on TensorFlow
https://t3f.readthedocs.io/en/latest/index.html
MIT License

get_config function #212

Open jyan26 opened 3 years ago

jyan26 commented 3 years ago

In t3f's nn.py, could we add a get_config method to the KerasDense class so that a trained model can be saved? Right now, without get_config, I get this error when I try to save the model or use callbacks in model.fit: "NotImplementedError: Layers with arguments in __init__ must override get_config." (I'm using TF 2.0.3.) Maybe we could add something like the following (I'm not sure it is right):

def get_config(self):
  config = super(KerasDense, self).get_config()
  config.update({'input_dims': self.input_dims,
                 'output_dims': self.output_dims,
                 'tt_rank': self.tt_rank,
                 'activation': self.activation,
                 'use_bias': self.use_bias,
                 'kernel_initializer': self.kernel_initializer,
                 'bias_initializer': self.bias_initializer})
  return config

Bihaqo commented 3 years ago

Great suggestion, thank you! Are you OK with me copy-pasting this into the library? Or, if you prefer, feel free to do it yourself and send a pull request.

jyan26 commented 3 years ago

Yes, sure, please feel free to copy and paste this code. I haven't had a chance to test it, though, so please test it together with the rest of the code before adding it to the library. Thanks!
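
Since the proposed method is untested, one way to check it is a config round trip: serialize the config, rebuild the layer from it, and compare. The sketch below shows that check in plain Python, without TensorFlow. `Layer`, `TTDenseSketch`, and `round_trip` are hypothetical stand-ins, not tf.keras or t3f classes. The sketch assumes the Keras convention that the default `from_config` calls `cls(**config)`, so every key in the config has to match an `__init__` argument name, and that the values have to be JSON-serializable for formats that store the config as JSON.

```python
import json


class Layer:
    """Hypothetical stand-in for a Keras-style base layer (not tf.keras)."""

    def __init__(self, name=None):
        self.name = name

    def get_config(self):
        # The base class only reports its own constructor arguments.
        return {'name': self.name}

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer by passing the config back to __init__.
        return cls(**config)


class TTDenseSketch(Layer):
    """Hypothetical layer whose arguments mirror part of the proposal above."""

    def __init__(self, input_dims, output_dims, tt_rank=2, activation=None,
                 use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.input_dims = list(input_dims)
        self.output_dims = list(output_dims)
        self.tt_rank = tt_rank
        self.activation = activation  # kept as a string name so it serializes
        self.use_bias = use_bias

    def get_config(self):
        # Extend the base config with every argument __init__ needs.
        config = super().get_config()
        config.update({'input_dims': self.input_dims,
                       'output_dims': self.output_dims,
                       'tt_rank': self.tt_rank,
                       'activation': self.activation,
                       'use_bias': self.use_bias})
        return config


def round_trip(layer):
    """Serialize a layer's config to JSON and rebuild a layer from it."""
    payload = json.dumps(layer.get_config())
    return type(layer).from_config(json.loads(payload))


original = TTDenseSketch([4, 4], [8, 8], tt_rank=3, activation='relu',
                         name='tt_dense_0')
restored = round_trip(original)
assert restored.get_config() == original.get_config()
```

If the real layer stores its constructor arguments under different attribute names, or stores a callable rather than a string for the activation, a check of this kind should fail at `json.dumps` or at the final comparison, which is the point of running it before merging.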