snorkel-team / snorkel

A system for quickly generating training data with weak supervision
https://snorkel.org
Apache License 2.0

Choose a memoization key in `preprocessor(memoize=True)` #1561

Closed: Wirg closed this issue 4 years ago

Wirg commented 4 years ago

Is your feature request related to a problem? Please describe.

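Currently, there is no way to choose the key used for memoization when using preprocessor(memoize=True). This leads to two issues:

  • memoization cannot be done for unhashable classes (typically a group of pandas rows);
  • the memoization key cannot be specific to a preprocessor.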

Describe the solution you'd like

Being able to parametrize it in the decorator.

@preprocessor(memoize=True, memoize_key=lambda p: p.base_website_url)
def add_website_reliability(paragraph):
    paragraph.website_reliability = evaluate_reliability(paragraph.base_website_url)
    return paragraph

Additional context

Current workaround:

from unittest import mock

@mock.patch("snorkel.map.core.get_hashable", lambda p: p.base_website_url)
@preprocessor(memoize=True)
def add_website_reliability(paragraph):
    paragraph.website_reliability = evaluate_reliability(paragraph.base_website_url)
    return paragraph
henryre commented 4 years ago

Hi @Wirg, thanks for posting! Sounds like you have a really interesting use case here; I'd love to hear about it if you want to share some detail! The solution you proposed makes sense to me, and I'd be happy to help out and review if you want to submit a PR. It looks like you'd need to make the change in get_hashable (https://github.com/snorkel-team/snorkel/blob/dd240003857ee4e95fa49599d708c5ba878b3939/snorkel/map/core.py#L41) and then make the keyword argument accessible up the stack.

Wirg commented 4 years ago

Hi @henryre and thank you for the answer,

I was thinking of a memoize_key-like argument that defaults to None, in which case the current get_hashable is used. Namely, something like:

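# Proposed change in snorkel/map/core.py, where get_hashable is defined
from typing import Callable, Hashable, List, Optional

from snorkel.types import DataPoint

class BaseMapper:
    def __init__(
        self,
        name: str,
        pre: List["BaseMapper"],
        memoize: bool,
        memoize_key: Optional[Callable[[DataPoint], Hashable]] = None,
    ) -> None:
        if memoize_key is None:
            # Fall back to the current behavior
            memoize_key = get_hashable
        self.name = name
        self._pre = pre
        self.memoize = memoize
        self._memoize_key = memoize_key
        self.reset_cache()

    def __call__(self, x: DataPoint) -> Optional[DataPoint]:
        if self.memoize:
            # Use the user-supplied key instead of get_hashable(x)
            x_hashable = self._memoize_key(x)
            if x_hashable in self._cache:
                return self._cache[x_hashable]
        ...  # rest of __call__ unchanged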

What do you think about this?

Concerning the use case, it's not the one I am currently working on, but I thought this one would be simpler to explain.
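For illustration, here is a minimal, Snorkel-free sketch of the design proposed above. `ToyMapper` and `default_hashable` are hypothetical stand-ins for `BaseMapper` and `get_hashable`: the fallback here only handles `SimpleNamespace` data points with hashable attributes, and the real mapper's `pre` handling is left out.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Hashable, Optional


def default_hashable(x: Any) -> Hashable:
    # Hypothetical fallback: hash every attribute of a SimpleNamespace
    return frozenset(vars(x).items())


class ToyMapper:
    def __init__(
        self,
        fn: Callable[[Any], Any],
        memoize: bool = False,
        memoize_key: Optional[Callable[[Any], Hashable]] = None,
    ) -> None:
        self._fn = fn
        self.memoize = memoize
        # None keeps the default behavior, as in the proposal
        self._memoize_key = default_hashable if memoize_key is None else memoize_key
        self.n_calls = 0  # how often fn actually ran
        self.reset_cache()

    def reset_cache(self) -> None:
        self._cache: Dict[Hashable, Any] = {}

    def __call__(self, x: Any) -> Any:
        key = self._memoize_key(x) if self.memoize else None
        if self.memoize and key in self._cache:
            return self._cache[key]
        self.n_calls += 1
        # Map a shallow copy so the input data point is not mutated
        x_mapped = self._fn(SimpleNamespace(**vars(x)))
        if self.memoize:
            self._cache[key] = x_mapped
        return x_mapped


def add_length(x):
    x.length = len(x.text)
    return x


by_default = ToyMapper(add_length, memoize=True)
by_text = ToyMapper(add_length, memoize=True, memoize_key=lambda x: x.text)

points = [SimpleNamespace(text="abc", id=1), SimpleNamespace(text="abc", id=2)]
for p in points:
    by_default(p)
    by_text(p)

print(by_default.n_calls, by_text.n_calls)  # → 2 1
```

With the default key the two points differ (their `id` differs), so `add_length` runs twice; keyed on `text` it runs once, and `by_text(points[1])` returns the cached copy built from `points[0]` (with `id == 1`). A custom key declares two data points interchangeable, so it has to cover everything the mapped result depends on.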

henryre commented 4 years ago

Hey @Wirg, that looks good to me! Just tag me on a PR once you submit!

Wirg commented 4 years ago

Hi @henryre, to pick up the discussion on the PR: https://github.com/snorkel-team/snorkel/pull/1572.

I wanted to solve two issues with this change.

  • memoization cannot be done for unhashable classes (typically a group of pandas rows). We need to wrap or subclass them.
  • the memoization key cannot be specific to a preprocessor. Example: we are trying to evaluate the reliability of a paragraph in a blog post. We could evaluate the reliability of the paragraph and of the website. The preprocessors corresponding to those two tasks will share the same memoization key, which is not ideal: a website can have a few thousand paragraphs, so we will evaluate website reliability far more often than necessary.
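To make the first issue concrete, here is a hedged, Snorkel-free sketch in which a data point is a group of rows (a list of dicts, standing in for a group of pandas rows) and is therefore unhashable. The hypothetical `memoize_by` decorator caches on whatever its key function returns, so a tuple of row ids is enough to make the group cacheable without wrapping or subclassing it.

```python
from typing import Any, Callable, Dict, Hashable, List

Row = Dict[str, Any]


def memoize_by(memoize_key: Callable[[Any], Hashable]):
    # Hypothetical decorator: caches fn's result under memoize_key(x)
    def decorator(fn):
        cache: Dict[Hashable, Any] = {}

        def wrapper(x):
            key = memoize_key(x)
            if key not in cache:
                cache[key] = fn(x)
            return cache[key]

        wrapper.cache = cache
        return wrapper
    return decorator


@memoize_by(lambda rows: tuple(r["id"] for r in rows))
def total_words(rows: List[Row]) -> int:
    return sum(len(r["text"].split()) for r in rows)


group = [{"id": 1, "text": "a first paragraph"}, {"id": 2, "text": "a second one"}]

try:
    hash(group)
except TypeError:
    print("a list of rows cannot be hashed directly")

print(total_words(group))  # → 6
print(len(total_words.cache))  # → 1
```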

Memoization cannot be done for unhashable classes

The first one is solved: we can introduce a memoize_key to hash any object we want. There is one red flag, though: the memoize_key has to be "really smart" and able to hash the results of any preprocessor bound to the DataPoint. Otherwise, in some cases we will get dangerous caching, which creates issues when the preprocessor is used by different LFs/preprocessors that apply their preprocessors in a different order.

from snorkel.labeling import labeling_function, LFApplier
from snorkel.preprocess import preprocessor
from types import SimpleNamespace

data = [
    SimpleNamespace(text='this a check', number=5),
    SimpleNamespace(text='this is something else', number=5)
]

# the memoize key here does not take into account
# all the attributes defined by other preprocessors
@preprocessor(memoize=True, memoize_key=lambda d: (d.number, d.text))
def square(x):
    x.square = x.number ** 2
    return x

@preprocessor(memoize=True)
def text_contains_check(x):
    x.contains_check = "check" in x.text
    return x

If we define our LFs and applier:

@labeling_function(pre=[square])
def square_gt_20(x):
    return x.square > 20

@labeling_function(pre=[text_contains_check, square])
def square_gt_10_and_contains_check(x):
    return x.square > 10 and x.contains_check

# this fails: `x.contains_check` is not defined inside `square_gt_10_and_contains_check`,
# because the cached result of `square` (computed for `square_gt_20`) overwrites it
applier = LFApplier([square_gt_20, square_gt_10_and_contains_check])
predictions = applier.apply(data)

# this succeeds, but the data point seen by `square_gt_20` carries the `contains_check`
# attribute added by `text_contains_check` through the cache, a silent side effect
applier = LFApplier([square_gt_10_and_contains_check, square_gt_20])
predictions = applier.apply(data)

This does not happen with pre=[square, text_contains_check], i.e. in the other order. Relying on this in a big code base seems like a red flag to me. I see Snorkel LFs as a way to design bigger and smarter expert systems by reducing the headache of combining multiple expert functions. As is, this change does not seem to fit that role: one will have to keep in mind how every preprocessor caches, and how to order them.

The memoization key cannot be specific to a preprocessor

My goal here was to compute only once some preprocessing that might be quite expensive but is repeated across many DataPoints. In the previous example, if square were really expensive (squaring a huge matrix?), it would be nice to apply it only once per new matrix to square, not once per new DataPoint together with its previous preprocessing. In this situation, the memoize_key of square would be lambda x: x.number, but then a cache hit returns the whole cached DataPoint: it overrides the results of any preprocessing not specific to the number attribute, and any other attributes as well.
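The difference between caching the whole data point and caching only the produced value can be reproduced in plain Python. In this hypothetical sketch (not Snorkel code), `cache_data_point` mimics data-point-level memoization keyed by `x.number` and returns the cached data point as a whole, while `cache_value` caches only the result of the expensive computation and sets it on the current data point, in the spirit of a producer-based approach.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Hashable


def cache_data_point(key: Callable[[Any], Hashable]):
    # Data-point level: on a cache hit, the whole cached data point is returned
    def decorator(fn):
        cache: Dict[Hashable, Any] = {}

        def wrapper(x):
            k = key(x)
            if k not in cache:
                cache[k] = fn(SimpleNamespace(**vars(x)))
            return cache[k]
        return wrapper
    return decorator


def cache_value(key: Callable[[Any], Hashable], attribute: str):
    # Field level: only the produced value is cached, then set on a copy of x
    def decorator(fn):
        cache: Dict[Hashable, Any] = {}

        def wrapper(x):
            k = key(x)
            if k not in cache:
                cache[k] = fn(x)
            x = SimpleNamespace(**vars(x))
            setattr(x, attribute, cache[k])
            return x
        return wrapper
    return decorator


@cache_data_point(key=lambda x: x.number)
def square_dp(x):
    x.square = x.number ** 2
    return x


@cache_value(key=lambda x: x.number, attribute="square")
def square_value(x):
    return x.number ** 2


a = SimpleNamespace(text="this a check", number=5)
b = SimpleNamespace(text="this is something else", number=5)

square_dp(a)
print(square_dp(b).text)  # prints: this a check (b's own text is lost)

square_value(a)
print(square_value(b).text)  # prints: this is something else
```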

Current workaround

I think the issue comes from the fact that snorkel.preprocess.preprocessor is expected both to produce the value from the DataPoint and to assign it back.

To work around this, I created this decorator, which transforms a producer (a function that returns a value without doing the assignment) into a preprocessor.

from functools import wraps
from typing import Any, Callable, Hashable, List, MutableMapping, Optional

from cachetools import cached
from snorkel.map.core import get_hashable, BaseMapper
from snorkel.preprocess import preprocessor
from snorkel.types import DataPoint

Preprocessor = Callable[[DataPoint], DataPoint]
Producer = Callable[[DataPoint], Any]
HashingFunction = Callable[[Any], Hashable]

def producer_to_snorkel_preprocessor(
        name: Optional[str] = None,
        attribute_key: Optional[str] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
        memoize_key: Optional[HashingFunction] = None,
        cache: Optional[MutableMapping] = None,
) -> Callable[[Producer], Preprocessor]:
    memoize_key = get_hashable if memoize_key is None else memoize_key

    def _decorator(producer: Producer) -> Preprocessor:
        if memoize:
            # Cache only the produced value, keyed by memoize_key(x),
            # not the whole data point
            _producer = cached({} if cache is None else cache, key=memoize_key)(producer)
        else:
            _producer = producer
        # By default, store the value under the producer's name
        _attribute_key = producer.__name__ if attribute_key is None else attribute_key

        @wraps(producer)
        def _producer_as_preprocessor(x: DataPoint):
            attribute_to_add = _producer(x)
            # Support dict-like data points (e.g. pd.Series) and plain objects
            if hasattr(x, '__setitem__'):
                x[_attribute_key] = attribute_to_add
            else:
                setattr(x, _attribute_key, attribute_to_add)
            return x

        # Memoization happens on the producer, so the preprocessor itself is not memoized
        return preprocessor(name=name, pre=pre)(_producer_as_preprocessor)

    return _decorator

And I use it this way:

from snorkel.labeling import labeling_function, LFApplier
from snorkel.preprocess import preprocessor
from types import SimpleNamespace
import numpy as np

@producer_to_snorkel_preprocessor(memoize=True, memoize_key=lambda x: x.number)
def square(x):
    return x.number ** 2

@producer_to_snorkel_preprocessor(memoize=True, memoize_key=lambda x: x.text)
def contains_check(x):
    return "check" in x.text

@labeling_function(pre=[square])
def square_gt_20(x):
    return x.square > 20

@labeling_function(pre=[contains_check, square])
def square_gt_10_and_contains_check(x):
    return x.square > 10 and x.contains_check

data = [SimpleNamespace(text='this a check', number=5), SimpleNamespace(text='this is something else', number=5)]

applier = LFApplier([square_gt_20, square_gt_10_and_contains_check])

np.testing.assert_array_equal(applier.apply(data), applier.apply(data[::-1])[::-1])

@henryre Do you think there is any way to make this cleaner? I feel a bit lost between Mapper, Preprocessor and their snake_case decorators.

henryre commented 4 years ago

Hi @Wirg thanks for the extra context here! I think we're pretty close to a solution with the PR you put up.

To your first point: cache key selection is always a tricky problem, and choosing a custom key is an advanced user operation. So I think it's fair to assume that anyone using this option is expected to understand how the preprocessor memoization works. A common approach here would be to pre-generate a random, unique ID for each data point.
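A minimal sketch of that suggestion, assuming data points can carry an extra attribute: give each one a random `uuid4` once, before any preprocessing, and key the cache on it. `memoize_by_key`, `add_uids` and the `uid` attribute are hypothetical names used only for illustration. The ID makes any data point cacheable, even one with unhashable fields, but it identifies the data point rather than the attributes other preprocessors add to it, so the ordering caveat raised earlier in the thread still applies.

```python
import uuid
from types import SimpleNamespace
from typing import Any, Callable, Dict, Hashable, List


def memoize_by_key(key: Callable[[Any], Hashable]):
    # Hypothetical memoized mapper: runs fn at most once per key(x)
    def decorator(fn):
        cache: Dict[Hashable, Any] = {}

        def wrapper(x):
            k = key(x)
            if k not in cache:
                cache[k] = fn(x)
            return cache[k]

        wrapper.cache = cache
        return wrapper
    return decorator


def add_uids(points: List[SimpleNamespace]) -> List[SimpleNamespace]:
    # Assign a random, unique ID to each data point once, up front
    for p in points:
        p.uid = uuid.uuid4()
    return points


@memoize_by_key(key=lambda x: x.uid)
def count_rows(x):
    # x.rows is a list, so hashing x's fields directly would raise TypeError
    x.n_rows = len(x.rows)
    return x


data = add_uids([SimpleNamespace(rows=[1, 2]), SimpleNamespace(rows=[1, 2])])
for _ in range(3):
    for x in data:
        count_rows(x)

print(len(count_rows.cache))  # → 2
```

The two data points have identical content but different IDs, so each gets its own cache entry, and the three passes over the data run `count_rows` only twice.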

To your second point about specific preprocessing: Snorkel is currently set up to cache preprocessing at the data-point level, not the field level. In many (not all) cases, however, I believe you should be able to accomplish what you're looking for using nested preprocessors. The preprocessor decorator (or the Preprocessor class initializer) can itself take a pre=[...] argument with a list of other Preprocessors. This means that you can choose to memoize at the outer layer only. So taking a look at your example:

from snorkel.labeling import labeling_function, LFApplier
from snorkel.preprocess import preprocessor
from types import SimpleNamespace
import numpy as np

@preprocessor(memoize=False)
def square(x):
    x.square = x.number ** 2
    return x

@preprocessor(memoize=True, pre=[square], memoize_key=lambda x: (x.text, x.number))
def contains_check_and_square(x):
    x.contains_check = "check" in x.text
    return x

@labeling_function(pre=[square])
def square_gt_20(x):
    return x.square > 20

@labeling_function(pre=[contains_check_and_square])
def square_gt_10_and_contains_check(x):
    return x.square > 10 and x.contains_check

If there are multiple expensive per-field operations, then the one-field-per-preprocessor approach likely won't work here. You could, however, put those in the same preprocessor.

@preprocessor(memoize=True, memoize_key=lambda x: (x.text, x.number))
def contains_check_and_square(x):
    x.square = x.number ** 2
    x.contains_check = "check" in x.text
    return x
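To see what memoizing only the outer layer buys, here is a toy, Snorkel-free sketch of nested preprocessors. `Pre` is a hypothetical stand-in for a preprocessor with a `pre` list: it runs its inner preprocessors and then its own function and, when a `memoize_key` is given, caches the final data point. The inner `square` step is not memoized itself, but on a cache hit of the outer step it does not run at all.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, Hashable, List, Optional


class Pre:
    def __init__(
        self,
        fn: Callable[[Any], Any],
        pre: Optional[List["Pre"]] = None,
        memoize_key: Optional[Callable[[Any], Hashable]] = None,
    ) -> None:
        self.fn = fn
        self.pre = pre or []
        self.memoize_key = memoize_key  # None means no memoization
        self.cache: Dict[Hashable, Any] = {}
        self.n_calls = 0  # how often fn actually ran

    def __call__(self, x: Any) -> Any:
        memoize = self.memoize_key is not None
        key = self.memoize_key(x) if memoize else None
        if memoize and key in self.cache:
            return self.cache[key]
        x = SimpleNamespace(**vars(x))  # work on a copy
        for p in self.pre:
            x = p(x)
        self.n_calls += 1
        x = self.fn(x)
        if memoize:
            self.cache[key] = x
        return x


def _square(x):
    x.square = x.number ** 2
    return x


def _contains_check(x):
    x.contains_check = "check" in x.text
    return x


square = Pre(_square)  # not memoized on its own
contains_check_and_square = Pre(
    _contains_check, pre=[square], memoize_key=lambda x: (x.text, x.number)
)

x = SimpleNamespace(text="this a check", number=5)
for _ in range(3):
    out = contains_check_and_square(x)

print(out.square, out.contains_check, square.n_calls)  # → 25 True 1
```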
Wirg commented 4 years ago

I can only agree that setting up a cache is complex.

I guess the feature as is solves some of the issues I was aiming for.

For the rest, I currently have the workaround I mentioned above, and you mentioned two other ways to achieve this.

I am waiting for your review (https://github.com/snorkel-team/snorkel/pull/1572).