albermax / innvestigate

A toolbox to iNNvestigate neural networks' predictions!
Other
1.26k stars 233 forks source link

Added LRP-analyzer class to allow application of LRP_flat on all (parametrized) layers until a given index #169

Closed leanderweber closed 4 years ago

enryH commented 4 years ago

I guess both kwargs are possible together, so we keep both?

LRP class __init__:

class LRP(base.ReverseAnalyzerBase):
    """
    Base class for LRP-based model analyzers.

    All options below are read from keyword arguments; any remaining
    positional/keyword arguments are forwarded to the parent constructor.

    :param model: A Keras model.

    :param rule: A rule can be a string or a Rule object, lists thereof
      or a list of conditions [(Condition, Rule), ... ]. Required.

    :param input_layer_rule: either a Rule object or a tuple of (low, high),
      the min/max pixel values of the inputs. A tuple is wrapped into a
      BoundedRule subclass with these bounds.

    :param until_layer_idx: index of the last layer (inclusive, counting
      from 0) to which ``until_layer_rule`` is applied. Only used when
      ``until_layer_rule`` is given as well.

    :param until_layer_rule: rule applied to all layers with an index
      up to and including ``until_layer_idx``. Only used when
      ``until_layer_idx`` is given as well.

    :param bn_layer_rule: either a Rule object or None.
      None means dedicated BN rule will be applied.

    :param bn_layer_fuse_mode: either "one_linear" (default) or
      "two_linear".

    :raises ValueError: if no rule is given.
    """

    def __init__(self, model, *args, **kwargs):
        rule = kwargs.pop("rule", None)
        input_layer_rule = kwargs.pop("input_layer_rule", None)

        # Both feature sets (master: until_layer_*, develop: bn_layer_*)
        # are independent of each other, hence both are supported.
        until_layer_idx = kwargs.pop("until_layer_idx", None)
        until_layer_rule = kwargs.pop("until_layer_rule", None)

        bn_layer_rule = kwargs.pop("bn_layer_rule", None)
        bn_layer_fuse_mode = kwargs.pop("bn_layer_fuse_mode", "one_linear")
        assert bn_layer_fuse_mode in ["one_linear", "two_linear"], (
            "bn_layer_fuse_mode must be 'one_linear' or 'two_linear'."
        )

        self._add_model_softmax_check()
        self._add_model_check(
            lambda layer: not kchecks.is_convnet_layer(layer),
            "LRP is only tested for convolutional neural networks.",
            check_type="warning",
        )

        # check if rule was given explicitly.
        # rule can be a string, a list (of strings) or a list of conditions [(Condition, Rule), ... ] for each layer.
        if rule is None:
            raise ValueError("Need LRP rule(s).")

        if isinstance(rule, list):
            # copy references
            self._rule = list(rule)
        else:
            self._rule = rule
        self._input_layer_rule = input_layer_rule
        self._until_layer_rule = until_layer_rule
        self._until_layer_idx = until_layer_idx
        self._bn_layer_rule = bn_layer_rule
        self._bn_layer_fuse_mode = bn_layer_fuse_mode

        if (
            isinstance(rule, six.string_types) or
            # NOTE: All LRP rules inherit from kgraph.ReverseMappingBase
            (inspect.isclass(rule) and issubclass(rule, kgraph.ReverseMappingBase))
        ):
            # the given rule is a single string or single rule implementing class
            use_conditions = True
            rules = [(lambda a, b: True, rule)]

        elif not isinstance(rule[0], tuple):
            # rule list of rule strings or classes
            use_conditions = False
            rules = list(rule)
        else:
            # rule is list of conditioned rules
            use_conditions = True
            # Copy the list: rules are inserted below and the list passed in
            # by the caller must not be modified.
            rules = list(rule)

        # apply rule to first self._until_layer_idx layers
        # NOTE(review): conditioned (Condition, Rule) tuples are inserted even
        # when use_conditions is False -- confirm this combination is intended.
        if self._until_layer_rule is not None and self._until_layer_idx is not None:
            for i in range(self._until_layer_idx + 1):
                # bound_i=i binds the current index at definition time;
                # a plain closure over i would see only the last value.
                rules.insert(0,
                             (lambda layer, foo, bound_i=i: kchecks.is_layer_at_idx(layer, bound_i),
                              self._until_layer_rule))

        # create a BoundedRule for input layer handling from given tuple
        if self._input_layer_rule is not None:
            input_layer_rule = self._input_layer_rule
            if isinstance(input_layer_rule, tuple):
                low, high = input_layer_rule

                class BoundedProxyRule(rrule.BoundedRule):
                    def __init__(self, *args, **kwargs):
                        super(BoundedProxyRule, self).__init__(
                            *args, low=low, high=high, **kwargs)
                input_layer_rule = BoundedProxyRule

            # Inserted last at position 0, i.e. in front of all other rules
            # (including the until-layer rules added above).
            if use_conditions is True:
                rules.insert(0,
                             (lambda layer, foo: kchecks.is_input_layer(layer),
                              input_layer_rule))

            else:
                rules.insert(0, input_layer_rule)

        self._rules_use_conditions = use_conditions
        self._rules = rules

        # FINALIZED constructor.
        super(LRP, self).__init__(model, *args, **kwargs)
enryH commented 4 years ago

Regarding the many ToDos: Should we move them to method docstrings so they can be more easily stumbled upon by new users?