QuantEcon / QuantEcon.py

A community based Python library for quantitative economics
https://quantecon.org/quantecon-py/
MIT License

[Update] ``Cartesian`` Module #143

Open mmcky opened 9 years ago

mmcky commented 9 years ago

Review cartesian module and update as required:

Current discussion from @albop: @oyamad, @mmcky: On the cartesian function, the reasoning was that the loops and list comprehensions run over the number of dimensions, so there are probably only a few iterations. Apart from the loop over _repeat_1d, everything else is likely to stay in Python mode, with little performance gain, if any. This was just my feeling, though; I didn't test it.
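albop's argument can be illustrated with a pure-NumPy sketch of a Cartesian product. This is a hypothetical reimplementation (cartesian_sketch is an illustrative name, not QuantEcon's code) that uses np.repeat and np.tile in place of the Numba helper _repeat_1d. The interpreted loop runs once per dimension, so a 3-D grid costs only three Python-level iterations, and all per-element work happens inside compiled NumPy routines.

```python
import itertools

import numpy as np


def cartesian_sketch(nodes, order="C"):
    """Hypothetical pure-NumPy Cartesian product (not QuantEcon's code)."""
    nodes = [np.asarray(e) for e in nodes]
    shapes = [e.shape[0] for e in nodes]
    total = int(np.prod(shapes))
    # Common dtype of all nodes (the profiled version uses nodes[0].dtype).
    out = np.empty((total, len(nodes)), dtype=np.result_type(*nodes))
    if total == 0:
        return out  # an empty node makes the whole product empty

    # Interpreted loop: one iteration per dimension, not per grid point.
    for i, x in enumerate(nodes):
        # Consecutive copies of each value = product of the sizes of the
        # dimensions that vary faster than dimension i.
        if order == "C":
            inner = int(np.prod(shapes[i + 1:]))  # later axes vary faster
        else:
            inner = int(np.prod(shapes[:i]))      # earlier axes vary faster
        outer = total // (inner * shapes[i])      # times the block is tiled
        # Vectorized column fill, the job _repeat_1d does in the library.
        out[:, i] = np.tile(np.repeat(x, inner), outer)
    return out


nodes = [np.arange(3), np.array([10.0, 20.0])]
grid = cartesian_sketch(nodes)
# Same rows, in the same order, as itertools.product (C order).
assert np.array_equal(grid, np.array(list(itertools.product(*nodes))))
print(grid[:3])  # rows (0, 10), (0, 20), (1, 10)
```

Because itertools.product enumerates in C order, the assertion checks the row order as well as the set of points.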

mmcky commented 9 years ago

It looks like this has been moved to gridtools.py, so the update should be applied to gridtools.py instead.

Can cartesian.py now be safely deleted?

rht commented 5 years ago

Line-profiling cartesian to see whether further Numba-fication would be useful.

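Using this benchmark:

import time
import numpy as np
from quantecon.gridtools import cartesian  # or a local @profile-decorated copy

x = np.linspace(0, 20, 100)

tic = time.time()
prod = cartesian([x, x, x])  # 100**3 = 1,000,000 rows
print(time.time() - tic)

# Correctness check on the digits 0-9: in C order, row i of the
# product must spell out the three decimal digits of i.
d = np.arange(10)
small = cartesian([d, d, d])
n = small[:, 0] * 100 + small[:, 1] * 10 + small[:, 2]
print(np.array_equal(n, np.arange(1000)))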

99.9% of the time inside cartesian is spent in the Numba-jitted helper _repeat_1d; every other line accounts for 0.0%:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    14                                           @profile
    15                                           def cartesian(nodes, order='C'):
    16                                               '''
    17                                               Cartesian product of a list of arrays
    18                                           
    19                                               Parameters
    20                                               ----------
    21                                               nodes : list(array_like(ndim=1))
    22                                           
    23                                               order : str, optional(default='C')
    24                                                   ('C' or 'F') order in which the product is enumerated
    25                                           
    26                                               Returns
    27                                               -------
    28                                               out : ndarray(ndim=2)
    29                                                   each line corresponds to one point of the product space
    30                                               '''
    31                                           
    32         1         18.0     18.0      0.0      nodes = [np.array(e) for e in nodes]
    33         1          8.0      8.0      0.0      shapes = [e.shape[0] for e in nodes]
    34                                           
    35         1          3.0      3.0      0.0      dtype = nodes[0].dtype
    36                                           
    37         1          2.0      2.0      0.0      n = len(nodes)
    38         1        103.0    103.0      0.0      l = np.prod(shapes)
    39         1         31.0     31.0      0.0      out = np.zeros((l, n), dtype=dtype)
    40                                           
    41         1          3.0      3.0      0.0      if order == 'C':
    42         1         65.0     65.0      0.0          repetitions = np.cumprod([1] + shapes[:-1])
    43                                               else:
    44                                                   shapes.reverse()
    45                                                   sh = [1] + shapes[:-1]
    46                                                   repetitions = np.cumprod(sh)
    47                                                   repetitions = repetitions.tolist()
    48                                                   repetitions.reverse()
    49                                           
    50         4         20.0      5.0      0.0      for i in range(n):
    51         3     443306.0 147768.7     99.9          _repeat_1d(nodes[i], repetitions[i], out[:, i])
    52                                           
    53         1          2.0      2.0      0.0      return out
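The Python-level setup in the listing only builds one repetition count per dimension, which is why those lines register as 0.0%. As a sketch, the lines before the _repeat_1d loop can be lifted into a standalone helper (repetition_counts is a hypothetical name, not a QuantEcon function) to see what _repeat_1d receives as its second argument for a small 2×3×4 grid:

```python
import numpy as np


def repetition_counts(shapes, order="C"):
    """Per-column repetition counts, mirroring the profiled cartesian().

    Hypothetical helper: the body follows the lines that precede the
    _repeat_1d loop in the line_profiler listing.
    """
    shapes = list(shapes)  # copy: the 'F' branch reverses in place
    if order == "C":
        repetitions = np.cumprod([1] + shapes[:-1]).tolist()
    else:
        shapes.reverse()
        repetitions = np.cumprod([1] + shapes[:-1]).tolist()
        repetitions.reverse()
    return repetitions


shapes = [2, 3, 4]
print(repetition_counts(shapes, "C"))  # [1, 2, 6]
print(repetition_counts(shapes, "F"))  # [12, 4, 1]

# In both orders, entry i is the product of the sizes of the dimensions
# that vary more slowly than dimension i, i.e. how often column i's
# block of values recurs down the output.
for i in range(len(shapes)):
    assert repetition_counts(shapes, "C")[i] == int(np.prod(shapes[:i]))
    assert repetition_counts(shapes, "F")[i] == int(np.prod(shapes[i + 1:]))
```

The copy at the top matters: the profiled code calls shapes.reverse() on its own local list, which is harmless there but would mutate the caller's list in a standalone helper.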

rht commented 5 years ago

#98 appears to have been resolved.