PSLmodels / ParamTools

Library for parameter processing and validation with a focus on computational modeling projects
https://paramtools.dev
MIT License

Parameters class object and numba jit compilation slowdown #82

Closed rickecon closed 5 years ago

rickecon commented 5 years ago

QUESTION: Can you verify that the numba package has reduced jit speed-up capacity in the presence of Python objects? And if so, do you know of any work-arounds? I really want to use the ParamTools approach of creating a parameters class object for passing parameters throughout the functions of a model. However, I think there is evidence provided below that the parameters class object creates a computational slowdown.

I am a huge fan of ParamTools. We use it very profitably in the OG-USA model. However, I think that the @numba.jit decorator is not able to efficiently deal with Python classes. I get a nearly 2x slowdown in a particular computation when the only change I make is to pass parameters via a class (as is done with ParamTools) versus passing the parameters via a standardized tuple.

Some evidence for this being an issue is in this Numba Issue ("#3907 Depricate object mode") from March 27, 2019 as well as this UMAP project issue thread ("#252 Numba warnings") from June 11, 2019. Both threads state that Python classes force numba to use the forceobj=True compilation method and prevent it from using the more efficient nopython compilation method.

I have posted my code that I used to conduct a horse race between two methods of passing parameters while using numba in the PubDebtNegShocks/code/Evans2020 directory of my PubDebtNegShocks repository. Feel free to clone the repository to verify this test.

Method 1: Parameters class argument passing (avg comp time: 2 min 44 sec)

In the first method, I create a parameters class object in the PubDebt_parameters.py file. This is called by the executing PubDebt_sims.py script in line 65. This script runs a single 20-period simulation of a two-period-lived overlapping generations model that has an ugly Euler equation with two integrals that must be solved. See equations (5) or (26) in the documentation for this model. The simulation is carried out with functions in the PubDebt_funcs.py module, each of which passes arguments using the parameters class object p and each of which has the @numba.jit(forceobj=True) decorator. The forceobj=True argument restricts numba from using the nopython default compilation method, which it would be forced to move away from anyway. This argument also suppresses a bunch of warning messages. The average computation time for one 20-period simulation is 2 minutes and 44 seconds.

Method 2: Single standard tuple argument passing (avg comp time: 1 min 29 sec)

In the second method, I do NOT use the parameters class module. The only change in my executing script PubDebt_sims_2.py is that I manually declare all the parameter names and values and then pass them throughout the functions for the simulation in PubDebt_funcs_2.py using a standardized tuple. The tuple is called mod_args in line 125 of PubDebt_sims_2.py, and either args or p_args in the PubDebt_funcs_2.py module. As in the previous case, each function in PubDebt_funcs_2.py has the @numba.jit(forceobj=True) decorator. The average computation time for one 20-period simulation is 1 minute and 29 seconds.

@hdoupe @jdebacker

jdebacker commented 5 years ago

@rickecon What I've done in models with a parameters class like ParamTools is to unpack the parameters when passing to a @numba.jit decorated function. I am not "jitting" all functions, only those where the decorator would significantly improve the performance of that function, so the parameters class is used to carry parameters around most of the model.
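A minimal sketch of the unpacking pattern described above. The `ModelParams` class, its attribute names and values, and the `utility` function are hypothetical stand-ins and are not taken from OG-USA or ParamTools. The `try`/`except` fallback is only there so the sketch also runs where numba is not installed.

```python
try:
    from numba import njit  # nopython-mode compilation when numba is available
except ImportError:  # fallback so the sketch still runs without numba
    def njit(func):
        return func


class ModelParams:
    """Hypothetical stand-in for a parameters class (not the ParamTools API)."""

    def __init__(self):
        self.beta = 0.96   # discount factor (illustrative value)
        self.sigma = 2.0   # risk aversion (illustrative value)


@njit
def utility(c, beta, sigma):
    # Only plain floats reach the compiled function, never the class instance.
    return beta * (c ** (1.0 - sigma) - 1.0) / (1.0 - sigma)


def discounted_utility(c, p):
    # Unpack the parameters object at the boundary of the jitted function.
    return utility(c, p.beta, p.sigma)


p = ModelParams()
result = discounted_utility(2.0, p)
```

The parameters object is still passed around the un-jitted parts of the model; only the call into `utility` sees its individual values.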

Also, I know that Tax-Calculator uses the Numba jit decorator and a ParamTools-like parameters class. I don't know exactly how it's handling both together, but you might look at the iterate_jit function to see an example where the Policy class (containing the policy parameters) is used with the jit decorator. It's not the most clear function, but it looks like the parameters in the class are being dumped into a list before being passed to the jitted function.

rickecon commented 5 years ago

Thanks @jdebacker. I got my computation time for both methods 1 and 2 down to around 50 seconds by only @jit-ing the functions that are the direct targets of the numerical integration commands. These two functions get called thousands of times. For each guess of an agent's savings, the integral needs to be solved by quadrature (an iterative process).

This tells me that the problem is not the parameters class as much as it is the complexity of the solution method. That is good news for ParamTools. Moral of the story for me is to use @jit intelligently and sparingly.
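A rough sketch of what jitting only the hot inner function can look like. The `integrand` function and the midpoint rule in `integrate` are hypothetical stand-ins for the model's Euler-equation integrands and its actual quadrature routine; the `try`/`except` fallback only keeps the sketch runnable without numba.

```python
try:
    from numba import njit  # nopython-mode compilation when numba is available
except ImportError:  # fallback so the sketch still runs without numba
    def njit(func):
        return func


@njit
def integrand(x, alpha):
    # Hot inner function: evaluated once per quadrature node.
    return x ** alpha


def integrate(alpha, lower, upper, n_nodes=1000):
    # Plain-Python midpoint rule; the outer loop is deliberately left un-jitted.
    width = (upper - lower) / n_nodes
    total = 0.0
    for i in range(n_nodes):
        total += integrand(lower + (i + 0.5) * width, alpha)
    return total * width


# Approximates the integral of x**2 over [0, 1], whose exact value is 1/3.
approx = integrate(2.0, 0.0, 1.0)
```

Whether this split pays off depends on how expensive each integrand evaluation is relative to the overhead of calling into compiled code, so it is worth timing both variants.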

hdoupe commented 5 years ago

Thanks for opening (and resolving) the issue @rickecon. This made me think of something: perhaps, there's a simpler way to call functions with parameters classes like paramtools.Parameters than what Tax-Calculator uses with its iterate_jit decorator. Python provides an inspect module that makes it easy to check out Python objects and see what they do. In particular, you can get a function's arguments with inspect.signature(some_function). Using this information, the values corresponding to the function argument names can be pulled from a Python class and used as arguments to this function.

If we have a function add:

def add(a, b):
    return a + b

Then, inspect.signature tells us that add takes two arguments, a and b:

import inspect

sig = inspect.signature(add)
dict(sig.parameters)

# {'a': <Parameter "a">, 'b': <Parameter "b">}

We can then call add with values pulled from an instance of the parameters class defined below, like this:

params = Parameters()
add(
    getattr(params, "a"),
    getattr(params, "b")
)

# array([[3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4]])

Putting it together, we can write a function that takes a parameters class and a function with arbitrary arguments, and calls that function using the values on the parameters class:

def call(params, func):
    sig = inspect.signature(func)
    t = []
    for arg in sig.parameters:
        t.append(
            getattr(params, arg)
        )
    return func(*t)

call(params, add)

# array([[3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4]])

This seems to work with jit-ed functions, too, but I've only tried it on this toy problem:


import numba

@numba.jit
def addjit(a, b):
    return a + b

call(params, addjit)

# array([[3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4],
#        [3.5, 9.3, 7.4]])
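One possible variation on `call`, sketched under the assumption that some function arguments carry defaults that the parameters class does not define. The names `call_with_defaults` and `scaled_add` are hypothetical, and `SimpleNamespace` stands in for a `paramtools.Parameters` instance so the sketch is self-contained.

```python
import inspect
from types import SimpleNamespace


def call_with_defaults(params, func):
    """Call func with arguments looked up by name on params.

    Arguments missing from params fall back to the default in func's signature.
    """
    kwargs = {}
    for name, param in inspect.signature(func).parameters.items():
        if hasattr(params, name):
            kwargs[name] = getattr(params, name)
        elif param.default is inspect.Parameter.empty:
            raise AttributeError(
                f"params has no value for required argument {name!r}"
            )
        # Otherwise the argument is left out and func uses its own default.
    return func(**kwargs)


def scaled_add(a, b, scale=1.0):
    return scale * (a + b)


# Hypothetical stand-in for a parameters object with attributes a and b.
params_stub = SimpleNamespace(a=1, b=2.5)
total = call_with_defaults(params_stub, scaled_add)
```

Because the values are passed by keyword, this sketch does not handle positional-only arguments or `*args`/`**kwargs` signatures.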

Full example is available as a notebook here: https://github.com/hdoupe/ParamTools-Examples/blob/master/notebooks/Jit.ipynb. @rickecon @jdebacker I'm curious to hear how something like this would work for your use cases.

Example class:

import paramtools

class Parameters(paramtools.Parameters):
    defaults = {
        "a": {
            "title": "A",
            "description": "",
            "type": "int",
            "value": [
                {"label1": 0, "label2": "one", "value": 1},
                {"label1": 0, "label2": "two", "value": 2},
                {"label1": 0, "label2": "three", "value": 3},
            ]
        },
        "b": {
            "title": "B",
            "description": "",
            "type": "float",
            "value": [
                {"label1": 0, "label2": "one", "value": 2.5},
                {"label1": 0, "label2": "two", "value": 7.3},
                {"label1": 0, "label2": "three", "value": 4.4},
            ]
        },
        "schema": {
            "labels": {
                "label1": {
                    "type": "int",
                    "validators": {"range": {"min": 0, "max": 10}}
                },
                "label2": {
                    "type": "str",
                    "validators": {"choice": {"choices": ["one", "two", "three"]}}
                }
            }
        }
    }
    array_first = True
    label_to_extend = "label1"