dice-group / Ontolearn

Ontolearn is an open-source software library for explainable structured machine learning in Python. It learns OWL class expressions from positive and negative examples.
https://ontolearn-docs-dice-group.netlify.app/index.html
MIT License

LRU prediction caching neural reasoner #489

Closed LckyLke closed 2 weeks ago

LckyLke commented 2 weeks ago

We can use the LRU caching wrapper from functools. To make the cache size dynamic and based on the current object, we have to create the function with the wrapper at runtime. This could look as follows:

from functools import lru_cache
from typing import List, Tuple

    def __init__(self, path_of_kb: str = None, path_neural_embedding: str = None, gamma: float = 0.25, max_cache_size: int = 2**20):
        [...]
        self.predict = self._create_cached_predict()

    def _create_cached_predict(self):
        self_ref = self

        @lru_cache(maxsize=self._max_cache_size)
        def predict(h: str = None, r: str = None, t: str = None) -> List[Tuple[str, float]]:
            # sanity checks
            assert h is not None or r is not None or t is not None, "At least one of h, r, or t must be provided."
            assert h is None or isinstance(h, str), "Head entity must be a string."
            assert r is None or isinstance(r, str), "Relation must be a string."
            assert t is None or isinstance(t, str), "Tail entity must be a string."

            if h is not None:
                if h not in self_ref.model.entity_to_idx:
                    # raise KeyError(f"Head entity '{h}' not found in model entity indices.")
                    return []
                h = [h]

            if r is not None:
                if r not in self_ref.model.relation_to_idx:
                    # raise KeyError(f"Relation '{r}' not found in model relation indices.")
                    return []
                r = [r]

            if t is not None:
                if t not in self_ref.model.entity_to_idx:
                    # raise KeyError(f"Tail entity '{t}' not found in model entity indices.")
                    return []
                t = [t]

            # If the relation is missing we rank relations, otherwise we rank entities.
            if r is None:
                topk = len(self_ref.model.relation_to_idx)
            else:
                topk = len(self_ref.model.entity_to_idx)

            return [(top_entity, score)
                    for top_entity, score in self_ref.model.predict_topk(h=h, r=r, t=t, topk=topk)
                    if score >= self_ref.gamma and is_valid_entity(top_entity)]

        return predict
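
Since predict is created with lru_cache, the usual functools hooks are available on the per-instance wrapper, which is handy for checking hit rates or invalidating the cache after the model changes. A usage sketch, assuming a hypothetical NeuralReasoner class built around the __init__ above (class name, KB path, and IRI are placeholders):

# Hypothetical usage; class name, KB path, and IRI are placeholders.
reasoner = NeuralReasoner(path_of_kb="KGs/family.owl", max_cache_size=2**20)

reasoner.predict(h="http://example.com/family#F9M167", r=None, t=None)  # computed
reasoner.predict(h="http://example.com/family#F9M167", r=None, t=None)  # served from the cache

print(reasoner.predict.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=1048576, currsize=1)
reasoner.predict.cache_clear()        # drop all cached predictions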

Runtimes get much better, e.g. on the family dataset:

    with caching:    5856/5856 [00:10<00:00, 540.20it/s]
    without caching: 5856/5856 [03:23<00:00, 28.83it/s]

Demirrr commented 2 weeks ago

Thank you. How can we disable it if we want? Maybe as shown below?

if enable_caching:
    self.predict = self._create_cached_predict()
LckyLke commented 2 weeks ago

Yes, I was thinking of conditionally applying the wrapper like so:


        if self._max_cache_size:
            predict = lru_cache(maxsize=self._max_cache_size)(predict)

        self.predict = predict

And remove the @lru_cache decorator from the top of the inner function, of course.

Demirrr commented 2 weeks ago

You forgot the else condition, didn't you?

LckyLke commented 2 weeks ago

No, if caching is enabled, predict gets overwritten with the cached wrapper; otherwise it stays the plain function. But I wrote this on my phone, so I haven't tested it yet.

LckyLke commented 2 weeks ago

from functools import lru_cache
from typing import List, Tuple

    def __init__(self, max_cache_size=None):
        self._max_cache_size = max_cache_size
        self.predict = self._create_predict_method()

    def _create_predict_method(self):
        self_ref = self

        def predict(h: str = None, r: str = None, t: str = None) -> List[Tuple[str, float]]:
            # prediction logic as in the snippet above (omitted here)
            pass

        # Apply caching only if max_cache_size is neither zero nor None
        if self._max_cache_size:
            predict = lru_cache(maxsize=self._max_cache_size)(predict)

        return predict

So, something like this.
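
One way to see which variant the factory produced: only the lru_cache-wrapped function carries the cache bookkeeping attributes, so a quick hasattr check distinguishes the two. In the sketch below, Reasoner is a hypothetical placeholder for the class holding the snippet above:

# Hypothetical usage; `Reasoner` stands in for the surrounding class.
plain = Reasoner(max_cache_size=None)
cached = Reasoner(max_cache_size=2**20)

print(hasattr(plain.predict, "cache_info"))   # False: plain function, no caching
print(hasattr(cached.predict, "cache_info"))  # True: lru_cache wrapper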

Demirrr commented 2 weeks ago

I don't think we would need _create_predict_method. What do you think about the following code?

    def __init__(self, max_cache_size: int = None):
        self._max_cache_size = max_cache_size
        if isinstance(max_cache_size, int) and max_cache_size > 0:
            self.predict = lru_cache(maxsize=max_cache_size)(self.predict)
        else:
            # don't do anything since self.predict is already defined
            pass
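
Wrapping the bound method this way gives each instance its own cache, since the wrapper is created per object rather than once at class definition time, and `self` is already bound so it is not part of the cache key. A minimal, self-contained sketch of the pattern (DemoReasoner and its trivial predict body are hypothetical stand-ins, not Ontolearn code):

from functools import lru_cache

class DemoReasoner:
    def __init__(self, max_cache_size: int = None):
        self._max_cache_size = max_cache_size
        if isinstance(max_cache_size, int) and max_cache_size > 0:
            # Shadow the class-level method with a cached bound method.
            self.predict = lru_cache(maxsize=max_cache_size)(self.predict)

    def predict(self, h: str = None, r: str = None, t: str = None):
        print(f"computing prediction for ({h}, {r}, {t})")
        return []

reasoner = DemoReasoner(max_cache_size=128)
reasoner.predict(h="markus")  # computed, prints
reasoner.predict(h="markus")  # answered from the per-instance cache, no print
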
LckyLke commented 2 weeks ago

Yes, this should also work - should I implement it and create a PR?

Demirrr commented 2 weeks ago

Yes please