zfit / zfit-development

The development repository for zfit, with roadmaps, internal docs, etc., to clean up the issues.

Caching TF 2.0 - API and Specs #47

Open jonas-eschle opened 5 years ago

jonas-eschle commented 5 years ago

Requirements

We need a cache that preserves computed values. Besides the technical difficulties of tracking dependencies, the invalidation behavior should be well defined.

Invalidation behavior (WIP)

If we find a consistent behavior, we can ask upstream whether an implementation would be considered, but this may be a little tricky.

A Cache is invalidated if:

Technical challenges

Since feed_dicts no longer exist in TF 2.0, we need an alternative way of caching. The following approach works fully and performs well in both graph and eager execution.

Caching the actual value

Create a class CachedTensor that handles all the dependencies and invalidates the cache. The caching logic works as follows:

Two Variables are created for each cached value: one holds the cache, the other holds a boolean flag indicating whether that cache is valid. They are used in conjunction with a tf.cond: tf.cond(flag, cache, cache_func), where cache_func is the wrapped function itself. When cache_func is executed, it not only returns the value but also 1) assigns it to the cache variable and 2) sets the validity flag to true.

Invalidation is easy: it only means setting the flag to false, which can be done outside of the graph runtime.
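The flag-plus-cache logic can be modelled without TensorFlow. The sketch below is a plain-Python stand-in, assuming a hypothetical CachedTensor class: the conditional expression in value() plays the role of tf.cond(flag, cache, cache_func), and the two attributes stand in for the two tf.Variables. In actual TF 2.0 code the branches would be traced and the stores would be Variable.assign calls.

```python
class CachedTensor:
    """Plain-Python model of the two-variable cache (hypothetical, not zfit API)."""

    def __init__(self, func):
        self._func = func      # the expensive function being wrapped
        self._cache = None     # stands in for the cache tf.Variable
        self._valid = False    # stands in for the boolean flag tf.Variable
        self.n_calls = 0       # how often the wrapped function actually ran

    def _cache_func(self):
        # Run the wrapped function, then 1) store the value and 2) mark it valid.
        self.n_calls += 1
        value = self._func()
        self._cache = value
        self._valid = True
        return value

    def value(self):
        # Equivalent of tf.cond(flag, lambda: cache, cache_func).
        return self._cache if self._valid else self._cache_func()

    def invalidate(self):
        # Invalidation only resets the flag; the stale value is overwritten later.
        self._valid = False


x = [3.0]
t = CachedTensor(lambda: x[0] ** 2)
print(t.value(), t.value(), t.n_calls)  # 9.0 9.0 1
x[0] = 4.0
t.invalidate()
print(t.value(), t.n_calls)  # 16.0 2
```

Because invalidating is nothing but a flag reset, it needs no access to the computation itself, which is why it can happen outside the graph runtime.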

There is currently an open discussion on this with Alex and Martin from Google: https://groups.google.com/a/tensorflow.org/forum/?utm_medium=email&utm_source=footer#!msg/discuss/yaYid8h85s0/ch_8f6oNAQAJ EDIT: the discussion concluded that this is the recommended way. Benchmarks show equal or even better performance compared to feed_dict, as shown here.

Gradients

Gradients can be handled with custom_gradient: heavy, technical work, but doable. We just need to "recursively" (or up to order 2) overwrite the gradient of the return value, which, again, should be a cache.
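One way to read "recursively (or up to order 2)" is that every cached value carries its own cached derivative, which again carries a cached derivative, up to a maximum order. The sketch below models that nesting in plain Python for a scalar function whose derivatives are supplied analytically. CachedDerivatives, set_param and get are hypothetical names. In TF, the derivative slot would be served by the grad function returned from a tf.custom_gradient-decorated function, so that the gradient of a cached value also comes from a cache.

```python
class CachedDerivatives:
    """Hypothetical cache for a value and its derivatives up to a maximum order.

    funcs[k] returns the k-th derivative with respect to the parameter, so the
    highest cached order is len(funcs) - 1 (e.g. 2 for value, gradient, Hessian).
    """

    def __init__(self, funcs, param):
        self._funcs = funcs
        self._param = param
        self._cache = [None] * len(funcs)
        self._valid = [False] * len(funcs)
        self.n_evals = [0] * len(funcs)

    def set_param(self, value):
        # A parameter change invalidates the value and all derivative orders together.
        self._param = value
        self._valid = [False] * len(self._funcs)

    def get(self, order=0):
        if order >= len(self._funcs):
            raise ValueError(f"order {order} exceeds the maximum cached order")
        if not self._valid[order]:
            self.n_evals[order] += 1
            self._cache[order] = self._funcs[order](self._param)
            self._valid[order] = True
        return self._cache[order]


# f(p) = p**3 together with its first and second derivatives
c = CachedDerivatives([lambda p: p ** 3, lambda p: 3 * p ** 2, lambda p: 6 * p], param=2.0)
print(c.get(0), c.get(1), c.get(1), c.get(2))  # 8.0 12.0 12.0 12.0
print(c.n_evals)  # [1, 1, 1]
c.set_param(1.0)
print(c.get(2), c.n_evals)  # 6.0 [1, 1, 2]
```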

Tracking dependencies

zfit tracks the dependencies of an object as a whole, which overestimates them, since certain functions may only depend on a subset of the dependencies (e.g. the pdf does not depend on a possible yield). A smarter way would be to watch which CacheInvalidater is touched during the execution of cache_func and to collect the dependencies this way.
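This "smarter way" can be sketched by letting every parameter report itself to whichever cache is currently computing. In the plain-Python model below, the hypothetical Param plays the role of a CacheInvalidater: reading it while a cache function runs registers that cache as a dependent, and set_value invalidates only those dependents. The module-level stack of active caches is an assumption of this sketch, not zfit or TensorFlow machinery.

```python
_active = []  # stack of caches whose compute function is currently running


class Param:
    """Hypothetical parameter acting as a CacheInvalidater."""

    def __init__(self, name, value):
        self.name = name
        self._value = value
        self._dependents = set()

    @property
    def value(self):
        if _active:  # read during a cache computation: record the dependency
            cache = _active[-1]
            self._dependents.add(cache)
            cache.deps.add(self.name)
        return self._value

    def set_value(self, value):
        self._value = value
        for cache in self._dependents:  # invalidate only caches that read this parameter
            cache.valid = False


class TrackingCache:
    """Hypothetical cache that learns its dependencies while computing."""

    def __init__(self, func):
        self.func = func
        self.deps = set()
        self.valid = False
        self.result = None
        self.n_calls = 0

    def __call__(self):
        if not self.valid:
            _active.append(self)
            try:
                self.result = self.func()
            finally:
                _active.pop()
            self.valid = True
            self.n_calls += 1
        return self.result


mu, sigma, yield_ = Param("mu", 0.0), Param("sigma", 1.0), Param("yield", 100.0)
shape = TrackingCache(lambda: mu.value + 2 * sigma.value)  # stand-in for a pdf shape
print(shape(), sorted(shape.deps))  # 2.0 ['mu', 'sigma']
yield_.set_value(200.0)  # never read by shape: cache stays valid
print(shape(), shape.n_calls)  # 2.0 1
sigma.set_value(3.0)  # read by shape: cache is invalidated
print(shape(), shape.n_calls)  # 6.0 2
```

Nested caches would additionally need to propagate invalidation from an inner cache to the outer caches that read it; this sketch leaves that out.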

Possible problems:

spflueger commented 4 years ago

This part is for sprouting possible ideas and sharing my thoughts: in ComPWA, the FunctionTree handles the tracking of dependencies and the triggering of recalculations via a "bidirectional graph" and an observer pattern. I don't like the bidirectional graph solution, so I would not recommend it. Based on the two use cases below, my current "best solution" would be to apply caching dynamically only at the optimization stage (where it is actually needed), as a pre-optimization step. My current idea was to work with several pipelines (partial calculations of a graph), each representing the case where a certain group of parameters changed. This might not be very feasible for many parameters, though, since the number of pipelines grows nonlinearly. I understand that TensorFlow determines the design, but maybe some of the above is useful.

It is important that both the data and the parameters of the model are regarded as inputs whose values can change and thereby trigger recalculations. To trigger a recalculation, the set_value method is handy.

For the pwa project (probably quite general), the use case is twofold:

  1. Simulation of data: parameters stay fixed, data changes
  2. Fitting/Optimization: parameters change, data stays fixed
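Treating both data and parameters as inputs, the two use cases can be illustrated by splitting a model into a data-only part and a parameter-only part, each cached against the input values it was last computed with. This sketch compares input values directly instead of using a validity flag. InputCache and model are hypothetical names, and data is passed as tuples so that in-place mutation cannot slip past the comparison.

```python
class InputCache:
    """Hypothetical cache that recomputes only when its input values change."""

    def __init__(self, func):
        self.func = func
        self._inputs = None  # input values the cached result was computed from
        self._value = None
        self.n_calls = 0

    def __call__(self, *inputs):
        if inputs != self._inputs:
            self.n_calls += 1
            self._value = self.func(*inputs)
            self._inputs = inputs
        return self._value


data_part = InputCache(lambda data: tuple(x * x for x in data))  # depends on data only
param_part = InputCache(lambda a, b: a / (a + b))  # depends on parameters only


def model(data, a, b):
    frac = param_part(a, b)
    return sum(frac * v for v in data_part(data))


# 1. Simulation: parameters fixed, data changes -> the parameter part is reused
print(model((1.0, 2.0), 1.0, 3.0), model((2.0,), 1.0, 3.0))  # 1.25 1.0
# 2. Fitting: data fixed, parameters change -> the data part is reused
print(model((2.0,), 3.0, 1.0))  # 3.0
print(data_part.n_calls, param_part.n_calls)  # 2 2
```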

At which level would caching be required? I think it would be best as a general feature, like a decorator on a node in a graph. A user could then either specify manually which nodes should be cached, or some intelligent code could determine that automatically (a possible future update), e.g. cut off constant parts of the graph and cache the values at the highest point.
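The idea to "cut off constant parts of the graph and cache the values at the highest point" can be made concrete as a graph walk. Assuming the computation is available as a plain dependency dict (node mapped to the nodes it consumes), the hypothetical constant_frontier below returns the highest non-leaf nodes that do not depend on any input expected to vary, i.e. where a cache would sit. The example graph and node names are illustrative only.

```python
def constant_frontier(graph, varying, output):
    """Highest non-leaf nodes below `output` that do not depend on `varying` inputs.

    graph maps each node to the list of nodes it consumes; leaves map to [].
    """
    memo = {}

    def depends(node):
        # True if `node` is, or transitively consumes, a varying input.
        if node not in memo:
            memo[node] = node in varying or any(depends(i) for i in graph[node])
        return memo[node]

    frontier, seen = set(), set()

    def visit(node):
        if node in seen:
            return
        seen.add(node)
        if not depends(node):
            if graph[node]:  # bare leaves need no precomputation
                frontier.add(node)
            return  # everything below a constant node is constant too
        for i in graph[node]:
            visit(i)

    visit(output)
    return frontier


graph = {
    "data": [], "mu": [], "sigma": [], "yield": [],
    "kinematics": ["data"],
    "shape": ["kinematics", "mu", "sigma"],
    "norm": ["mu", "sigma"],
    "pdf": ["shape", "norm"],
    "extended": ["pdf", "yield"],
}
print(constant_frontier(graph, {"mu", "sigma", "yield"}, "extended"))  # {'kinematics'}
print(constant_frontier(graph, {"data"}, "extended"))  # {'norm'}
```

Each set of varying inputs gives a different frontier, so the fitting and simulation use cases end up caching different parts of the same graph.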

jonas-eschle commented 4 years ago

That sounds good, and similar to what I had in mind. The idea is to cache it functionally:

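```python
# inside a pdf method
result = some_expensive_calcs(...)
# wrap the value so that it is cached and invalidated when any of `deps` changes
result = self.cache_value(result, deps=...)
return result
```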

where deps can be entered "manually" or (by default?) be self.get_deps(), which returns all the dependencies of the pdf.
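A minimal sketch of cache_value with a get_deps() default, under two assumptions: every parameter carries a version counter that set_value increments, and the calculation is passed as a lazy callable. Plain Python needs that callable, whereas in graph mode the result tensor is itself lazy, which is what makes the functional form above possible. Param, BasePDF, Gauss and the key argument are hypothetical names for illustration.

```python
import math


class Param:
    """Hypothetical parameter; the version counter changes on every set_value."""

    def __init__(self, value):
        self.value = value
        self.version = 0

    def set_value(self, value):
        self.value = value
        self.version += 1


class BasePDF:
    """Hypothetical base class sketching cache_value(..., deps=...)."""

    def __init__(self, params):
        self.params = dict(params)
        self._caches = {}  # key -> (dependency versions at compute time, value)

    def get_deps(self):
        # Default: all parameters of the pdf (overestimates, e.g. includes the yield).
        return list(self.params.values())

    def cache_value(self, key, compute, deps=None):
        deps = self.get_deps() if deps is None else deps
        versions = tuple(p.version for p in deps)
        hit = self._caches.get(key)
        if hit is not None and hit[0] == versions:
            return hit[1]
        value = compute()  # only runs on a miss or after a dependency changed
        self._caches[key] = (versions, value)
        return value


class Gauss(BasePDF):
    def __init__(self, mu, sigma, yield_):
        super().__init__({"mu": mu, "sigma": sigma, "yield": yield_})
        self.n_norm_evals = 0

    def _calc_norm(self):
        # Stand-in for an expensive normalization integral.
        self.n_norm_evals += 1
        return self.params["sigma"].value * math.sqrt(2 * math.pi)

    def norm(self):
        return self.cache_value("norm", self._calc_norm)  # deps default to get_deps()


mu, sigma, yld = Param(0.0), Param(1.0), Param(50.0)
gauss = Gauss(mu, sigma, yld)
gauss.norm()
gauss.norm()
print(gauss.n_norm_evals)  # 1
yld.set_value(60.0)  # the yield is in the default deps, so this re-triggers
gauss.norm()
print(gauss.n_norm_evals)  # 2
```

Passing deps=[sigma] explicitly would keep the cached normalization valid when only the yield changes, avoiding the overestimation of the default.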

This can easily be added as an aspect with a decorator to a method as well.
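The decorator variant could look like the following sketch. The hypothetical cached_method wraps a method, falls back to self.get_deps() when no dependencies are given, and keys the cache on the method's positional arguments as well as on the version counters of its dependencies. The version counter is an assumption of this sketch standing in for the validity flag; it is not an existing zfit mechanism.

```python
import functools


class Param:
    """Hypothetical parameter with a version counter bumped by set_value."""

    def __init__(self, value):
        self.value = value
        self.version = 0

    def set_value(self, value):
        self.value = value
        self.version += 1


def cached_method(deps=None):
    """Hypothetical decorator: cache a method's result until one of its deps changes.

    `deps` maps the instance to a list of Params; by default self.get_deps() is used.
    Only positional, hashable arguments are supported; they are part of the cache key.
    """

    def decorator(method):
        @functools.wraps(method)
        def wrapper(self, *args):
            params = deps(self) if deps is not None else self.get_deps()
            versions = tuple(p.version for p in params)
            store = self.__dict__.setdefault("_method_cache", {})
            key = (method.__name__, args)
            if key in store and store[key][0] == versions:
                return store[key][1]
            value = method(self, *args)
            store[key] = (versions, value)
            return value

        return wrapper

    return decorator


class Line:
    """f(x) = slope * x, with a cached integral from 0 to `upper`."""

    def __init__(self, slope):
        self.slope = slope
        self.n_evals = 0

    def get_deps(self):
        return [self.slope]

    @cached_method()
    def integral(self, upper):
        self.n_evals += 1
        return 0.5 * self.slope.value * upper ** 2


slope = Param(2.0)
line = Line(slope)
print(line.integral(3.0), line.integral(3.0), line.n_evals)  # 9.0 9.0 1
slope.set_value(4.0)
print(line.integral(3.0), line.n_evals)  # 18.0 2
```

An explicit subset can be given as @cached_method(deps=lambda self: [self.slope]), which corresponds to entering the deps "manually".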

This would imply that any parameter the PDF depends on (or a change of the dataset) re-triggers a calculation.

This is maybe not optimal, but I think it is a pretty good way. I guess we can start out with this.

On the pre-optimization: I was thinking of a more general 'preoptimize', 'precalc', 'optimize' or similar method, which can be invoked before the first call to a pdf and performs multiple optimizations, also on sub-pdfs (e.g. in a sum). It would be nice to have a bunch of "similar" methods across the packages.
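A 'precalc'-style method that also covers sub-pdfs could recurse through the composition, handling children before their parent so that the parent can build on already precomputed parts. The sketch below assumes a hypothetical PDF class with a pdfs list of components and an empty _precalc hook for each pdf's own optimizations; a flag ensures that a sub-pdf shared by several parents is processed only once.

```python
class PDF:
    """Hypothetical pdf node with sub-pdfs (e.g. the components of a sum)."""

    def __init__(self, name, pdfs=()):
        self.name = name
        self.pdfs = list(pdfs)
        self.precalculated = False

    def _precalc(self):
        # Place for this pdf's own one-off optimizations (e.g. caching constant parts).
        pass

    def precalc(self):
        """Optimize all sub-pdfs first, then this pdf; return the processing order."""
        if self.precalculated:  # shared sub-pdfs are only processed once
            return []
        order = []
        for pdf in self.pdfs:
            order.extend(pdf.precalc())
        self._precalc()
        self.precalculated = True
        order.append(self.name)
        return order


gauss, expo = PDF("gauss"), PDF("exp")
total = PDF("sum", [gauss, expo])
print(total.precalc())  # ['gauss', 'exp', 'sum']
print(total.precalc())  # []
```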