ez2rok / recursion-visualizer

Visualize recursive functions with beautiful animations
https://ez2rok.github.io/recursion-visualizer/
Apache License 2.0

Cache functions with non-hashable arguments #2

Open ez2rok opened 2 years ago

ez2rok commented 2 years ago

Goal

Behind the scenes, 'RecursionVisualizer' uses a dictionary to cache the return value of each recursive call, keyed by that call's arguments. Because dictionary keys must be hashable, this only works when every argument of the recursive function is hashable. However, these arguments are sometimes lists, which are not hashable.
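To make the failure mode concrete, here is a minimal sketch of dictionary-based memoization. It is a generic illustration under stated assumptions, not RecursionVisualizer's actual code (which also tracks depth, history, and edge labels): the decorator name `memoize` and the example functions are hypothetical, and the cache is keyed on the positional-argument tuple `args`, as the traceback further down suggests `memoized_func` does.

```python
from functools import wraps


def memoize(func):
    """Hypothetical minimal memoizer: caches results keyed by the args tuple."""
    cache = {}

    @wraps(func)
    def memoized_func(*args):
        # The membership test hashes the whole args tuple, which in turn
        # hashes every element; a list element raises TypeError here.
        if args not in cache:
            cache[args] = func(*args)
        return cache[args]

    return memoized_func


@memoize
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)


@memoize
def total(items, i):
    # Sums items[0:i] recursively; `items` is the kind of argument that may be a list.
    return 0 if i == 0 else items[i - 1] + total(items, i - 1)


print(fib(30))  # 832040: integer arguments are hashable, so caching works

try:
    total([1, 2, 3], 3)
except TypeError as exc:
    print(exc)  # reports that 'list' is an unhashable type
```

Passing a tuple instead, e.g. `total((1, 2, 3), 3)`, works because every element of `args` is then hashable.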

My goal is to have 'RecursionVisualizer' work for all recursive functions, even if they have non-hashable arguments such as lists.

The Issue

For example, consider this implementation of the 0-1 knapsack problem:

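from recursion_visualizer.recurse import RecursionVisualizer  # module path as shown in the traceback below

@RecursionVisualizer()
def knapsack(capacity, weights, values, i):
    # Base case: no items left to consider, or no remaining capacity
    if i == 0 or capacity == 0:
        return 0
    # Item i-1 is too heavy to fit, so skip it
    if weights[i-1] > capacity:
        return knapsack(capacity, weights, values, i-1)
    # Otherwise take the better of including or excluding item i-1
    return max(values[i-1] + knapsack(capacity-weights[i-1], weights, values, i-1),
               knapsack(capacity, weights, values, i-1))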

which could be called by writing

weights = [10, 20, 30]
values = [60, 100, 120]
capacity = 50

knapsack(capacity, weights, values, len(values))

With the current implementation (RecursionVisualizer version 0.0.0), this raises the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/Users/eitanturok/Desktop/projects/recursion-visualizer/index.ipynb Cell 12 in <cell line: 5>()
      2 values = [60, 100, 120]
      3 capacity = 50
----> 5 knapsack(capacity, weights, values, len(values))

File ~/Desktop/projects/recursion-visualizer/recursion_visualizer/recurse.py:105, in RecursionVisualizer.__call__.<locals>.memoized_func(*args, **kwargs)
    103 self.depth += 1
    104 # if args not in self.cache:
--> 105 self.cache[args] = func(*args, **kwargs)
    106 self.depth -= 1
    108 # record node's output, finish time, history, and edge_label

/Users/eitanturok/Desktop/projects/recursion-visualizer/index.ipynb Cell 12 in knapsack(capacity, weights, values, i)
      5 if weights[i-1] > capacity:
      6   return knapsack(capacity, weights, values, i-1)
----> 7 return max(values[i-1] + knapsack(capacity-weights[i-1], weights, values, i-1),
      8            knapsack(capacity, weights, values, i-1))

File ~/Desktop/projects/recursion-visualizer/recursion_visualizer/recurse.py:105, in RecursionVisualizer.__call__.<locals>.memoized_func(*args, **kwargs)
    103 self.depth += 1
    104 # if args not in self.cache:
--> 105 self.cache[args] = func(*args, **kwargs)
    106 self.depth -= 1
...
--> 105 self.cache[args] = func(*args, **kwargs)
    106 self.depth -= 1
    108 # record node's output, finish time, history, and edge_label

TypeError: unhashable type: 'list'

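The last line of the traceback, TypeError: unhashable type: 'list', appears because knapsack takes the arguments weights and values, which are both lists and therefore unhashable. The cache key args is a tuple, and a tuple is only hashable if all of its elements are, so using it as a dictionary key fails.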
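A subtlety matters for any automatic fix that tries to detect such arguments: `collections.abc.Hashable` only checks whether a type defines `__hash__`, so `isinstance(args, Hashable)` is `True` for a tuple even when hashing it raises. A sketch of a more reliable check, using a hypothetical helper `is_hashable`, simply attempts the hash:

```python
from collections.abc import Hashable


def is_hashable(value):
    """Hypothetical helper: True only if hash(value) actually succeeds."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


args = (50, [10, 20, 30], [60, 100, 120], 3)  # shaped like knapsack's arguments

print(isinstance(args, Hashable))  # True: tuple defines __hash__
print(is_hashable(args))           # False: the nested lists cannot be hashed
print(is_hashable((50, (10, 20, 30), (60, 100, 120), 3)))  # True
```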

What I've Tried

I currently have two workarounds for this.

1) Make all data structures hashable. In this case, I can cast weights and values to tuples and everything works:

weights = tuple([10, 20, 30])
values = tuple([60, 100, 120])
capacity = 50

knapsack(capacity, weights, values, len(values))

This runs without errors.

2) Rewrite knapsack as an outer wrapper function that takes the unhashable weights and values as arguments, plus an inner function whose only arguments are capacity and the index i. The recursion happens in the inner function, and because it is defined inside the wrapper, it can still read weights and values from the enclosing scope (a closure) without taking these unhashable lists as arguments.

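def knapsack_wrapper(capacity, weights, values):
    """Wrapper: receives the unhashable lists but does not recurse itself."""

    @RecursionVisualizer()
    def knapsack(capacity, i):
        """Recursive worker: only hashable arguments; reads weights/values via closure."""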

        if i == 0 or capacity == 0:
            return 0
        if weights[i-1] > capacity:
            return knapsack(capacity, i-1)
        return max(values[i-1] + knapsack(capacity-weights[i-1], i-1),
                   knapsack(capacity, i-1))

    return knapsack(capacity, len(weights))

This code works without any issues.

Moving Forward

Although these are both valid workarounds, each one requires users to change their own code, so I'm still looking for a more general-purpose solution.

ez2rok commented 2 years ago

The Stack Overflow post "Memoization recipe that allows non-hashable arguments" recommends implementing a custom hash function that uses an object's __repr__ for hashing.

However, this is inefficient, since the full repr string must be built on every call, and it still has correctness issues. Several commenters on the post point out that hashing an object's `repr` does not guarantee that equal arguments produce the same key, since a dictionary of sets (for example) may print its values in an arbitrary order.
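A minimal sketch of the repr-based approach shows both why it accepts lists and where it breaks. The decorator name `memoize_by_repr` is hypothetical and this is not the code from the post. The pitfall demonstrated uses plain dictionaries rather than a dictionary of sets: two dicts that compare equal but were built in a different insertion order have different reprs, so they would get separate cache entries.

```python
from functools import wraps


def memoize_by_repr(func):
    """Hypothetical sketch: cache keyed on the repr() of the arguments."""
    cache = {}

    @wraps(func)
    def memoized_func(*args):
        key = repr(args)  # a string, so always hashable, but costs O(size of args) to build
        if key not in cache:
            cache[key] = func(*args)
        return cache[key]

    memoized_func.cache = cache  # exposed only so the example can inspect it
    return memoized_func


@memoize_by_repr
def total(items, i):
    return 0 if i == 0 else items[i - 1] + total(items, i - 1)


print(total([1, 2, 3], 3))  # 6: lists are accepted because only their repr is hashed

# Pitfall: equal arguments are not guaranteed to have equal reprs.
d1 = {"a": 1, "b": 2}
d2 = {"b": 2, "a": 1}
print(d1 == d2)              # True
print(repr(d1) == repr(d2))  # False, so d1 and d2 would miss each other's cache entries
```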

ez2rok commented 2 years ago

GitHub user @slowli created a gist that tries to turn non-hashable arguments into hashable ones with these lines:

mapargs = lambda *args: args

def cached_fn(*args):
    mapped_args = mapargs(*args)
    if not mapped_args in cache:
        cache[mapped_args] = fn(*args)
    return cache[mapped_args]

Yet all mapargs does is return args, which is already a tuple, so it does not help when the tuple's elements are themselves unhashable (e.g. lists). I tried implementing it anyway and, as expected, it did not resolve the error.
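That said, the gist applies mapargs before the cache lookup, so replacing the identity lambda with a function that converts arguments into hashable equivalents would make the lookup succeed. A minimal sketch, assuming lists map to tuples, sets to frozensets, and dicts to frozensets of their (key, value) pairs; the helper name `freeze` is hypothetical:

```python
def freeze(value):
    """Hypothetical helper: recursively convert common mutable containers to hashable ones."""
    if isinstance(value, list):
        return tuple(freeze(v) for v in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(freeze(v) for v in value)
    if isinstance(value, dict):
        # Dict keys are already hashable; only the values need freezing.
        return frozenset((k, freeze(v)) for k, v in value.items())
    return value  # assumed hashable already; hash() still raises if it is not


def mapargs(*args):
    # Drop-in replacement for the gist's identity `lambda *args: args`.
    return tuple(freeze(a) for a in args)


key = mapargs(50, [10, 20, 30], {"a": [1, 2]})
print(key)
print(hash(key))  # no longer raises TypeError
```

One trade-off: after freezing, `[1, 2]` and `(1, 2)` produce the same key, which is only safe if the function treats them the same way. As a side effect, dicts with equal contents get equal keys regardless of insertion order.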

ez2rok commented 2 years ago

I'm thinking of writing some code in RecursionVisualizer that automatically casts list arguments to tuples. This has two advantages:

  1. This fix is easy to implement.
  2. In the recursive functions RecursionVisualizer is likely to encounter, the values that change from call to call are usually pointers (indices) or integers, while the lists we recurse over usually stay the same. So I don't expect converting these lists to immutable tuples to affect the behavior of the vast majority of recursive functions used with RecursionVisualizer.