google / nitroml

NitroML is a modular, portable, and scalable model-quality benchmarking framework for Machine Learning and Automated Machine Learning (AutoML) pipelines.
Apache License 2.0
41 stars 6 forks source link

Programmatic skipping #46

Closed cweill closed 4 years ago

cweill commented 4 years ago
class SkipWrapper:
  """Stand-in for a component that should not be added to the pipeline.

  Wraps a component and builds a `tfx.ResolverNode` named
  `skip_<wrappee.id>` that uses `LatestArtifactsResolver` over the wrapped
  component's output channels. `outputs` exposes the resolver's outputs;
  every other attribute lookup is forwarded to the wrapped component.
  """

  def __init__(self, wrappee):
    """Creates the resolver node that replaces `wrappee`.

    Args:
      wrappee: The component to stand in for. Its `id` is used to name the
        resolver and its `outputs` mapping is unpacked as the resolver's
        keyword arguments.
    """
    self._wrappee = wrappee
    self._resolver = tfx.ResolverNode(
        instance_name=f'skip_{wrappee.id}',
        resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
        **wrappee.outputs)

  @property
  def outputs(self) -> Dict[str, types.Channel]:
    """The resolver node's outputs, used in place of the wrappee's outputs."""
    return self._resolver.outputs

  @property
  def resolver(self) -> tfx.ResolverNode:
    """The resolver node created for the wrapped component."""
    return self._resolver

  def __getattr__(self, attr):
    """Forwards attributes not found on the wrapper to the wrapped component.

    Raises:
      AttributeError: If `attr` is missing from both the wrapper and the
        wrapped component, or if the wrapper has no wrapped component yet.
    """
    # `__getattr__` only runs after normal lookup has failed. If `_wrappee`
    # itself is what is missing (the instance was created without running
    # `__init__`, e.g. by `copy` or `pickle`), reading `self._wrappee` below
    # would re-enter this method until a RecursionError is raised.
    if attr == '_wrappee':
      raise AttributeError(attr)
    return getattr(self._wrappee, attr)

class Benchmark(abc.ABC):
  """Base class for a benchmark made up of one or more benchmark methods.

  The design of this class is inspired by `unittest.TestCase`.

  Several Benchmark subclasses may live in the same benchmark file; together
  they form a suite of benchmarks.
  """

  def __init__(self):
    self._benchmark = self  # The sub-benchmark stack.
    self._result = None
    self._seen_benchmarks = None
    # Components collected by `run`, in the order they were added.
    self._pipeline = []
    # True while execution is inside a `skip()` block.
    self._within_skip_context = False

  @abc.abstractmethod
  def benchmark(self, **kwargs):
    """Benchmark method that subclasses must override.

    Args:
      **kwargs: Keyword args that are propagated from the call to
        nitroml.run(...).
    """

  def _run_component(self, component):
    """Adds `component`, or a skip stand-in for it, to the pipeline.

    Args:
      component: The component to add. If it has a `components` attribute,
        each of those is added instead of `component` itself.

    Returns:
      A `SkipWrapper` around `component` when inside a `skip()` block,
      otherwise `component` itself.
    """
    if self._within_skip_context:
      logging.warning(
          'Skipping "%s". Downstream components will use most recent artifacts.',
          component.id)
      # Only the wrapper's resolver goes into the pipeline; the component
      # itself is left out.
      stand_in = SkipWrapper(component)
      self._pipeline.append(stand_in.resolver)
      return stand_in

    if hasattr(component, 'components'):
      self._pipeline.extend(component.components)
    else:
      self._pipeline.append(component)
    return component

  def _rename_component(self, component):
    """Qualifies the component's instance name with this benchmark's id.

    After renaming, every output channel's `producer_component_id` is set to
    the component's `id`.

    Args:
      component: The component to rename in place.
    """
    # pylint: disable=protected-access
    qualified = _qualified_name(component._instance_name, self.id())
    component._instance_name = qualified
    for _key, channel in component.outputs.items():
      channel.producer_component_id = component.id
    # pylint: enable=protected-access

  def run(self, component):
    """Renames `component` for this benchmark and adds it to the pipeline.

    Args:
      component: The component to add. If it has a `components` attribute,
        each of those is renamed instead of `component` itself.

    Returns:
      Whatever `_run_component` returns for `component`.
    """
    if hasattr(component, 'components'):
      to_rename = component.components
    else:
      to_rename = [component]
    for member in to_rename:
      self._rename_component(member)
    return self._run_component(component)

  @contextlib.contextmanager
  def skip(self):
    """Context manager that marks components run inside it as skipped.

    The previous skip state is restored on exit, including when the body
    raises, so `skip()` blocks can be nested.
    """
    previous_state = self._within_skip_context
    self._within_skip_context = True
    try:
      yield
    finally:
      self._within_skip_context = previous_state