ZPZhou-lab / tfkan

The TensorFlow implementation of KANs
MIT License
38 stars, 8 forks

add get_config and from_config methods #6

Open jalalmzh opened 4 months ago

jalalmzh commented 4 months ago

If you want to save and load models that contain these layers, you should add `get_config` and `from_config` methods to the custom layers, like this:

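    # Methods to add inside the DenseKAN layer class
    def get_config(self):
        # Start from the base Layer config (name, trainable, dtype, ...)
        # and add every constructor argument of DenseKAN.
        config = super().get_config()
        config.update({
            'units': self.units,
            'grid_size': self.grid_size,
            'spline_order': self.spline_order,
            'grid_range': self.grid_range,
            'basis_activation': self.basis_activation,
            'use_bias': self.use_bias,
            'spline_initialize_stddev': self.spline_initialize_stddev,
        })
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the layer from the dict produced by get_config()
        return cls(**config)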

Then register them with `tf.keras.utils.get_custom_objects()['DenseKAN'] = DenseKAN`.
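To check that the save/load round trip works end to end, here is a minimal sketch using a hypothetical toy layer, `ScaledDense` (not part of tfkan), that follows the same pattern: `get_config` returns every constructor argument, the class is registered through `tf.keras.utils.get_custom_objects()`, and the model is saved to and reloaded from the native `.keras` format. It assumes a recent TensorFlow 2.x; the layer, its `scale` argument, the file name and the shapes are illustrative only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical toy layer (not from tfkan): y = scale * (x @ w)."""

    def __init__(self, units, scale=1.0, **kwargs):
        super().__init__(**kwargs)  # forwards name, trainable, dtype, ...
        self.units = units
        self.scale = scale

    def build(self, input_shape):
        self.w = self.add_weight(
            name="w",
            shape=(int(input_shape[-1]), self.units),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        return self.scale * tf.matmul(inputs, self.w)

    def get_config(self):
        # Base config plus every constructor argument, so that the default
        # from_config (cls(**config)) can rebuild the layer.
        config = super().get_config()
        config.update({"units": self.units, "scale": self.scale})
        return config


# Register under the class name, as suggested in the issue.
tf.keras.utils.get_custom_objects()["ScaledDense"] = ScaledDense

inputs = tf.keras.Input(shape=(4,))
outputs = ScaledDense(3, scale=2.0)(inputs)
model = tf.keras.Model(inputs, outputs)

x = np.random.default_rng(0).random((2, 4), dtype=np.float32)
expected = model.predict(x, verbose=0)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.keras")
    model.save(path)
    # No custom_objects argument needed: the class was registered globally.
    restored = tf.keras.models.load_model(path)
    restored_out = restored.predict(x, verbose=0)

print(np.allclose(expected, restored_out))
```

Global registration is optional: passing `custom_objects={'ScaledDense': ScaledDense}` to `tf.keras.models.load_model` works as well. Because the base Keras `Layer.from_config` essentially already calls `cls(**config)`, overriding `from_config` mostly matters when some config values must be converted back before they reach the constructor.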

ZPZhou-lab commented 4 months ago

Thanks for the feedback, I will add these two methods to each KAN layer. 🫡