Closed cweill closed 4 years ago
class SkipWrapper:
  """Stand-in for a component that is skipped inside `Benchmark.skip()`.

  Wraps a component and builds a `tfx.ResolverNode` (using
  `LatestArtifactsResolver`) over the wrapped component's output channels.
  `outputs` exposes the resolver's outputs; any attribute not defined on the
  wrapper itself is looked up on the wrapped component.
  """

  def __init__(self, wrappee):
    """Creates the wrapper and its resolver node.

    Args:
      wrappee: The component to wrap. Its `id` is used to name the resolver
        (`skip_<id>`) and its `outputs` mapping is forwarded as keyword
        arguments to the resolver node.
    """
    self._wrappee = wrappee
    self._resolver = tfx.ResolverNode(
        instance_name=f'skip_{wrappee.id}',
        resolver_class=latest_artifacts_resolver.LatestArtifactsResolver,
        **wrappee.outputs)

  @property
  def outputs(self) -> Dict[str, types.Channel]:
    """The output channels of the resolver node."""
    return self._resolver.outputs

  @property
  def resolver(self) -> tfx.ResolverNode:
    """The resolver node created for the wrapped component."""
    return self._resolver

  def __getattr__(self, attr):
    """Delegates attributes missing on the wrapper to the wrapped component.

    Raises:
      AttributeError: If `attr` is `_wrappee` and it has not been set yet.
    """
    # `__getattr__` only runs when normal lookup fails. If `_wrappee` itself is
    # missing (e.g. the instance was created via `__new__` by copy/pickle),
    # reading `self._wrappee` below would re-enter this method forever.
    if attr == '_wrappee':
      raise AttributeError(
          f'{type(self).__name__!r} object has no attribute {attr!r}')
    return getattr(self._wrappee, attr)


class Benchmark(abc.ABC):
  """A benchmark which can be composed of several benchmark methods.

  The Benchmark object design is inspired by `unittest.TestCase`. A benchmark
  file can contain multiple Benchmark subclasses to compose a suite of
  benchmarks.
  """

  def __init__(self):
    self._benchmark = self  # The sub-benchmark stack.
    self._result = None
    self._seen_benchmarks = None
    # Components collected so far, in the order they were run.
    self._pipeline = []
    # True while executing inside a `skip()` context.
    self._within_skip_context = False

  @abc.abstractmethod
  def benchmark(self, **kwargs):
    """Benchmark method to be overridden by subclasses.

    Args:
      **kwargs: Keyword args that are propagated from the call to
        nitroml.run(...).
    """

  @staticmethod
  def _leaf_components(component):
    """Returns `component.components` if present, else `[component]`."""
    if hasattr(component, 'components'):
      return component.components
    return [component]

  def _run_component(self, component):
    """Adds `component` to the pipeline, honoring the skip context.

    Inside a `skip()` context the component is not added; a `SkipWrapper` is
    built around it and the wrapper's resolver node is added instead.
    Otherwise the component is added, or each entry of its `components`
    attribute when it has one.

    Args:
      component: The component to add.

    Returns:
      The `SkipWrapper` when skipping, otherwise `component` itself.
    """
    if self._within_skip_context:
      logging.warning(
          'Skipping "%s". '
          'Downstream components will use most recent artifacts.',
          component.id)
      skip_wrapper = SkipWrapper(component)
      self._pipeline.append(skip_wrapper.resolver)
      return skip_wrapper

    self._pipeline.extend(self._leaf_components(component))
    return component

  def _rename_component(self, component):
    """Qualifies the component's instance name with this benchmark's id.

    The component's output channels are updated to point at the new id.

    Args:
      component: The component to rename in place.
    """
    # pylint: disable=protected-access
    component._instance_name = _qualified_name(component._instance_name,
                                               self.id())
    for channel in component.outputs.values():
      channel.producer_component_id = component.id
    # pylint: enable=protected-access

  def run(self, component):
    """Renames `component` (or its sub-components) and adds it to the pipeline.

    Args:
      component: The component to run. If it has a `components` attribute,
        each entry is renamed individually.

    Returns:
      The value returned by `_run_component`: a `SkipWrapper` inside a
      `skip()` context, otherwise `component`.
    """
    for leaf in self._leaf_components(component):
      self._rename_component(leaf)
    return self._run_component(component)

  @contextlib.contextmanager
  def skip(self):
    """Context manager under which run components are skipped.

    The previous skip state is restored on exit, including when the body
    raises, so contexts can be nested.
    """
    old_within_skip_context = self._within_skip_context
    self._within_skip_context = True
    try:
      yield
    finally:
      self._within_skip_context = old_within_skip_context