psykei / psyki-python

PSyKI: a (Python) platform for symbolic knowledge injection
https://psykei.github.io/psyki-python/
Apache License 2.0

clone all old weights into the same layers #18

Closed github-actions[bot] closed 2 years ago

github-actions[bot] commented 2 years ago

https://github.com/psykei/psyki-python/blob/cd0c4b9a40f6c67d8d919eb22d536cef8d07958f/psyki/ski/kins/__init__.py#L72


from __future__ import annotations
from typing import Callable, Iterable, List
from tensorflow import Tensor
from tensorflow.keras import Model
from tensorflow.keras.layers import Concatenate
from tensorflow.python.keras.saving.save import load_model
from tensorflow.python.keras.utils.generic_utils import get_custom_objects
from psyki.logic.datalog.grammar import optimize_datalog_formula
from psyki.ski import Injector
from psyki.logic import Formula, Fuzzifier
from psyki.utils import eta, eta_one_abs, eta_abs_one, model_deep_copy

class NetworkStructurer(Injector):
    """
    This injector builds a set of modules, i.e. ad-hoc layers, and inserts them into the predictor (a neural network).
    In this way the predictor can exploit the knowledge via these modules, which mimic the logic formulae.
    With the default fuzzifier this is the implementation of KINS: knowledge injection via network structuring.
    """

    custom_objects: dict[str, Callable] = {'eta': eta, 'eta_one_abs': eta_one_abs, 'eta_abs_one': eta_abs_one}

    def __init__(self, predictor: Model, feature_mapping: dict[str, int], fuzzifier: str, layer: int = 0):
        """
        @param predictor: the predictor.
        @param feature_mapping: a map between variables in the logic formulae and indices of dataset features. Example:
            - 'PL': 0,
            - 'PW': 1,
            - 'SL': 2,
            - 'SW': 3.
        @param fuzzifier: the fuzzifier used to map the knowledge (SubNetworkBuilder by default).
        @param layer: the index of the layer at which to perform the injection.
        """
        self._predictor: Model = model_deep_copy(predictor)
        # self.feature_mapping: dict[str, int] = feature_mapping
        if layer < 0 or layer > len(predictor.layers) - 2:
            raise Exception('Cannot inject knowledge into layer ' + str(layer) +
                            '.\nYou can inject from layer 0 to ' + str(len(predictor.layers) - 2))
        self._layer = layer
        # Use SubNetworkBuilder as the default fuzzifier.
        self._fuzzifier = Fuzzifier.get(fuzzifier)([self._predictor.input, feature_mapping])
        self._fuzzy_functions: Iterable[Callable] = ()

    def inject(self, rules: List[Formula]) -> Model:
        self._clear()
        # Prevent side effect on the original rules during optimization.
        rules_copy = [rule.copy() for rule in rules]
        for rule in rules_copy:
            optimize_datalog_formula(rule)
        predictor_input: Tensor = self._predictor.input
        modules = self._fuzzifier.visit(rules_copy)
        if self._layer == 0:
            # Injection!
            x = Concatenate(axis=1)([predictor_input] + modules)
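            # Rebuild the first trainable layer so its weights match the widened, concatenated input.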
            self._predictor.layers[1].build(x.shape)
            for layer in self._predictor.layers[1:]:
                x = layer(x)
        else:
            x = self._predictor.layers[1](predictor_input)
            for layer in self._predictor.layers[2:self._layer + 1]:
                x = layer(x)
            # Injection!
            x = Concatenate(axis=1)([x] + modules)
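            # Rebuild the layer right after the injection point so its weights match the concatenated input.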
            self._predictor.layers[self._layer + 1].build(x.shape)
            x = self._predictor.layers[self._layer + 1](x)
            for layer in self._predictor.layers[self._layer + 2:]:
                # Correct shape if needed (e.g., dropout layers)
                if layer.input_shape != x.shape:
                    layer.build(x.shape)
                x = layer(x)
        new_predictor = Model(predictor_input, x)
        # TODO: clone all old weights into the same layers

        get_custom_objects().update(self.custom_objects)
        return new_predictor

    @staticmethod
    def load(file: str):
        return load_model(file, custom_objects=NetworkStructurer.custom_objects)

    def _clear(self):
        self._fuzzy_functions = ()
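
One possible way to resolve the TODO in inject(), sketched below: once the new model has been assembled, copy the original predictor's weights back into the corresponding layers wherever the tensor shapes still match, so that only the layers rebuilt around the injection point start from fresh weights. The helper name clone_old_weights and the name-based layer matching are illustrative assumptions, not part of the PSyKI API.

from tensorflow.keras import Model


def clone_old_weights(old_predictor: Model, new_predictor: Model) -> None:
    """Hypothetical helper: copy the old predictor's weights into the new one.

    Layers are matched by name, assuming model_deep_copy preserves layer names;
    layers whose weight shapes changed (e.g. those rebuilt after the Concatenate)
    are skipped and keep their freshly initialised weights.
    """
    new_layers_by_name = {layer.name: layer for layer in new_predictor.layers}
    for old_layer in old_predictor.layers:
        new_layer = new_layers_by_name.get(old_layer.name)
        if new_layer is None:
            # The old layer has no counterpart in the rebuilt model.
            continue
        old_weights = old_layer.get_weights()
        new_weights = new_layer.get_weights()
        # Copy only when every weight tensor has the same shape in both models.
        if len(old_weights) == len(new_weights) and \
                all(o.shape == n.shape for o, n in zip(old_weights, new_weights)):
            new_layer.set_weights(old_weights)

For this to be callable at the end of inject(), the original model would need to be kept around (e.g. stored in __init__ alongside the deep copy), since self._predictor only holds the copied, possibly rebuilt layers.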