sandialabs / WecOptTool

WEC Design Optimization Toolbox
https://sandialabs.github.io/WecOptTool/
GNU General Public License v3.0

Feature request: Avoid repeating identical calculations (speedup) #306

Open rebeccamccabe opened 10 months ago

rebeccamccabe commented 10 months ago

Feature description. A way to cache functions so that repeated evaluations of identical inputs are not recalculated during optimization. Or, a way to remove the time-consuming error-checking from functions during optimization and instead validate the inputs beforehand.

Issue addressed. Inspired by #305, I tried profiling wecopttool to see what is taking up the most time. Below are the results of a sweep of optimizations with 27 different impedances, with a regular wave, nsubsteps = 16, nfreq = 5, use_grad=False.

image

Surprisingly, a substantial portion of the time (2.33 / 6.12 = 38%) is spent repeatedly (6244 times) evaluating the wave_excitation function. Much of the time in this function comes from the error checking (subset_close, allclose). This seems strange because the result of this function is constant during the optimization, so it could be computed a single time and reused for the whole optimization.

Describe alternatives you've considered. In another project, I've used functools.cache, which caches function evaluations when the inputs are identical, with just a decorator. I managed to get it working on WecOptTool, but it was a bit more complicated than I expected because xarray objects aren't hashable, so I had to use dask (which conveniently is already a WecOptTool dependency through wavespectra). You can see from the profiling results that wave_excitation now only gets called once, and the other 6243 times the cached value is returned.
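To illustrate the hashability obstacle, here is a minimal sketch using only the standard library. The function name scaled is hypothetical, and a plain list stands in for an xarray object (both are unhashable); this is not WecOptTool code.

```python
import functools

calls = []  # records every time the function body actually runs


@functools.cache
def scaled(values, factor):
    """Hypothetical stand-in for an expensive, deterministic function."""
    calls.append(1)
    return tuple(v * factor for v in values)


# Hashable arguments (tuple, float): the second call is served from the cache.
scaled((1.0, 2.0), 3.0)
scaled((1.0, 2.0), 3.0)
n_calls = len(calls)

# Unhashable argument (list): functools.cache cannot build a key and raises.
try:
    scaled([1.0, 2.0], 3.0)
    unhashable_error = None
except TypeError as err:
    unhashable_error = err
```

This is why the arguments need to be wrapped in something hashable before functools.cache can be used on functions that take array-like inputs.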

image

Unfortunately, this doesn't help, because the dask hashing (which takes around 1 ms per argument, so 2 ms per function call) is slower than the <1 ms it takes to actually execute the wave_excitation function, so the overall time goes up substantially. Now most of the time is spent in the __init__ method of the wrapper class, which calls dask.base.tokenize to generate the hash.

Describe the solution you'd like. This would benefit users who plan to run a large number of optimizations, for whom speed is important. Perhaps there is a better way to do the hashing that would make my solution viable, i.e.:

Or, ditch hashing altogether and do something more manual:

Interest in leading this feature development? Sure, but I have reached the limit of my knowledge of hashing, so someone with more CS experience might be more useful than me if that approach is preferred.

Additional information. The timings were done using main (commit 598e875cd48751ddb652657391205f6e208714f0) rather than 2.6.0, because 2.6.0 has a transpose call in the wave_excitation function that seems to slow it down a lot. I used this xarray issue and this stackoverflow page to figure out the dask hashing.

My code for dask hashing, inserted into core.py:

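import functools

import dask
import dask.base


class HashWrapper:
    """Wrap an unhashable object so it can serve as a functools.cache key."""
    def __init__(self, x) -> None:
        self.value = x
        # Tokenize once on construction; raise instead of silently falling
        # back to a non-deterministic token.
        with dask.config.set({"tokenize.ensure-deterministic": True}):
            self.h = dask.base.tokenize(x)
    def __hash__(self) -> int:
        return hash(self.h)
    def __eq__(self, other: object) -> bool:
        # Guard against comparison with objects that have no token.
        return isinstance(other, HashWrapper) and other.h == self.h


def hashable_cache(function):
    """Cache `function` on the dask tokens of its arguments."""
    @functools.cache
    def cached_wrapper(*args, **kwargs):
        # Unwrap the original objects before calling the real function.
        arg_values = [a.value for a in args]
        kwargs_values = {
            k: v.value for k, v in kwargs.items()
        }
        return function(*arg_values, **kwargs_values)
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # Wrap every argument so functools.cache sees hashable keys.
        shell_args = [HashWrapper(a) for a in args]
        shell_kwargs = {
            k: HashWrapper(v) for k, v in kwargs.items()
        }
        return cached_wrapper(*shell_args, **shell_kwargs)

    # Expose the usual functools.cache helpers on the public wrapper.
    wrapper.cache_info = cached_wrapper.cache_info
    wrapper.cache_clear = cached_wrapper.cache_clear

    return wrapper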

plus the @hashable_cache decorator applied above the wave_excitation function definition. Hypothetically, if the caching were actually faster, the decorator could be added to any function with dask-hashable inputs (i.e. NumPy arrays and xarray objects, but not custom classes unless a custom tokenize method were written).
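One possible way to avoid paying for a content hash on every call is to key the cache on object identity rather than content. The sketch below is not from the thread: identity_cache and wave_excitation_stub are hypothetical names, plain lists stand in for arrays, and it assumes the same argument objects are passed on every call and are never mutated in place. If either assumption fails, the cache would miss or return stale values.

```python
import functools


def identity_cache(function):
    """Cache results keyed on the identity (id) of positional arguments."""
    cache = {}

    @functools.wraps(function)
    def wrapper(*args):
        key = tuple(id(a) for a in args)  # cheap: no content hashing
        hit = cache.get(key)
        if hit is not None:
            return hit[1]
        result = function(*args)
        # Store the arguments too, so they stay alive and their ids
        # cannot be reused by unrelated objects.
        cache[key] = (args, result)
        return result

    wrapper.cache_clear = cache.clear
    return wrapper


calls = []  # records every time the function body actually runs


@identity_cache
def wave_excitation_stub(exc_coeff, waves):
    """Hypothetical stand-in for an excitation calculation."""
    calls.append(1)
    return [c * w for c, w in zip(exc_coeff, waves)]


exc_coeff = [1.0, 2.0, 3.0]
waves = [0.5, 0.5, 0.5]
first = wave_excitation_stub(exc_coeff, waves)
second = wave_excitation_stub(exc_coeff, waves)  # same objects: cache hit
```

The trade-off is that two equal but distinct arrays are treated as different inputs, so this only pays off when the caller reuses the same objects throughout the optimization.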

cmichelenstrofer commented 9 months ago

This is great! I think we can make the wave excitation function always return some pre-calculated value. This would need to be set up when we call WEC.from_impedance or WEC.from_bem.
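A rough sketch of the pre-calculation idea, independent of the actual WecOptTool internals: the helper precompute_force and the function excitation_stub are hypothetical, and the returned force function simply ignores whatever arguments the solver passes to it.

```python
def precompute_force(compute, *inputs):
    """Evaluate compute(*inputs) once and return a function that replays it."""
    value = compute(*inputs)  # runs exactly once, input validation included

    def force(*_args, **_kwargs):
        # Arguments are ignored on purpose: the stored result is assumed
        # not to depend on the optimization state.
        return value

    return force


calls = []  # records every time the expensive function actually runs


def excitation_stub(exc_coeff, wave_elevation):
    """Hypothetical stand-in: validate the inputs, then compute a force."""
    calls.append(1)
    if len(exc_coeff) != len(wave_elevation):
        raise ValueError("frequency vectors do not match")
    return [c * w for c, w in zip(exc_coeff, wave_elevation)]


# Set up once, e.g. at construction time, then call freely in the solver loop.
f_exc = precompute_force(excitation_stub, [2.0, 4.0], [0.5, 0.25])
results = [f_exc("wec", "x_wec", "x_opt", "waves") for _ in range(1000)]
```

This also moves the error checking out of the optimization loop, since validation happens only in the single up-front evaluation. It would not apply to a nonlinear excitation force that depends on the state.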

ryancoe commented 9 months ago

Note that we may have done it this way because we were trying to allow for nonlinear excitation. However, this is only really possible using the default WEC.__init__ method. Solution: for the other static init methods, make it so excitation is only calculated once.

@rebeccamccabe - Thanks for doing this. Can you share with us what tools you used to do the profiling so we can explore this a bit more?

rebeccamccabe commented 9 months ago

I used the default profiler that comes with the Spyder IDE. I had to turn autograd off to get interpretable profile results, because otherwise most computations show up as coming from the autograd box function wrapper.
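For anyone without Spyder, comparable per-function call counts and timings can be collected with the standard library's cProfile and pstats modules. The sketch below profiles a toy loop; excitation_stub, check_inputs, and fake_optimization are hypothetical stand-ins, not WecOptTool functions.

```python
import cProfile
import pstats


def check_inputs(values):
    """Stand-in for per-call error checking."""
    return all(isinstance(v, float) for v in values)


def excitation_stub(values):
    """Stand-in for a function re-evaluated on every iteration."""
    check_inputs(values)
    return sum(values)


def fake_optimization(n_iter=200):
    values = [0.1, 0.2, 0.3]
    return [excitation_stub(values) for _ in range(n_iter)]


profiler = cProfile.Profile()
profiler.enable()
fake_optimization()
profiler.disable()

stats = pstats.Stats(profiler)
stats.sort_stats("cumulative").print_stats(5)  # top entries by cumulative time

# stats.stats maps (file, line, function name) to
# (primitive calls, total calls, total time, cumulative time, callers).
call_counts = {key[2]: val[1] for key, val in stats.stats.items()}
```

Sorting by cumulative time highlights functions that are individually cheap but called thousands of times, which is the pattern described for wave_excitation above.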