dask / dask-glm

Approximate exponents #23

Open mrocklin opened 7 years ago

mrocklin commented 7 years ago

Fun fact: the current ADMM implementation can spend almost half of its time computing np.exp

Script

from dask import persist
import dask.array as da
import numpy as np
from dask_glm.logistic import admm
from dask_glm.utils import make_y

N = int(1e7)
M = 2
chunks = int(1e6)
seed = 20009

X = da.random.random((N, M), chunks=(chunks, M))
y = make_y(X, beta=np.array(list(range(M))), chunks=chunks)

X, y = persist(X, y)

%%prun
import dask
with dask.set_options(get=dask.get):
    beta = admm(X, y)

Profile results

I use a wrapped version of np.exp just so that it shows up in profile results.

         565331 function calls (532847 primitive calls) in 93.202 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2254   41.543    0.018   41.543    0.018 utils.py:25(exp)
     1127   30.614    0.027   52.086    0.046 logistic.py:349(logistic_loss)
     1127   13.122    0.012   40.358    0.036 logistic.py:364(logistic_gradient)
     1127    6.267    0.006   27.191    0.024 logistic.py:343(sigmoid)
     1130    0.793    0.001    0.793    0.001 {method 'reduce' of 'numpy.ufunc' objects}
        8    0.078    0.010    0.079    0.010 executionengine.py:100(finalize_object)
     1127    0.052    0.000   40.409    0.036 logistic.py:373(proximal_logistic_gradient)
     1127    0.049    0.000   52.135    0.046 logistic.py:358(proximal_logistic_loss)
        8    0.043    0.005    0.043    0.005 passmanagers.py:94(run)
       30    0.034    0.001   93.040    3.101 lbfgsb.py:205(_minimize_lbfgsb)
     2254    0.031    0.000   41.582    0.018 dispatcher.py:152(__call__)
     1127    0.029    0.000   92.998    0.083 lbfgsb.py:277(func_and_grad)
26957/12971    0.027    0.000    0.083    0.000 {method 'format' of 'str' objects}
     1127    0.021    0.000   52.417    0.047 optimize.py:290(function_wrapper)
    80776    0.019    0.000    0.030    0.000 {built-in method isinstance}
     1127    0.013    0.000    0.814    0.001 fromnumeric.py:1743(sum)

Comments

Here are some comments from @stuartarchibald

How many exp() are you doing at a time? If you are doing many, are the input values of similar magnitude? Do you care about IEEE754 correctness over NaN and Inf?

When I've dealt with this previously, most of the time spent in a loop over exp() was in a) feraiseexcept or similar dealing with inf/nan, and b) branch misprediction; the split points are usually f(log_e(2)) IIRC.

a) goes away if you don't care; b) goes away if you can guarantee a range or aren't too bothered about inaccuracy out of range. Remez algorithms are usually used to compute the coefficient table for a polynomial, and a bunch of shifts are done to get values into the appropriate range. c) If you have a load of values to compute, perhaps try Intel VML?

cc @seibert @mcg1969

mrocklin commented 7 years ago

@jcrist this may interest you

mcg1969 commented 7 years ago

In my experience we don't need to care about Inf/NaN checking, but we do need to make sure that it returns Inf if the input is too large.

cicdw commented 7 years ago

I personally think that it could be useful to use inexact / approximate computations for a large piece of the algorithm and then use exact computations for the final steps. A high-level overview of what I mean:

def algo(X,y):
    beta_hat = iterate_with_approx_exp_until_convergence(X, y)
    beta_hat = iterate_with_exact_exp_until_convergence(X, y, init=beta_hat)
    return beta_hat

The first piece will hopefully be faster, and the second piece ensures that we are still solving the true problem.
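
To make the warm-start pattern concrete, here is a hypothetical sketch using scipy's L-BFGS-B (not dask-glm code; approx_loss_grad and exact_loss_grad are placeholder objectives returning (loss, gradient), one built on an approximate exp and one on the exact exp):

import numpy as np
from scipy.optimize import minimize

def fit_two_stage(X, y, approx_loss_grad, exact_loss_grad):
    beta0 = np.zeros(X.shape[1])
    # Stage 1: loose tolerance with the cheap, approximate objective.
    rough = minimize(approx_loss_grad, beta0, args=(X, y), jac=True,
                     method="L-BFGS-B", options={"gtol": 1e-4})
    # Stage 2: polish with the exact objective, warm-started from stage 1.
    final = minimize(exact_loss_grad, rough.x, args=(X, y), jac=True,
                     method="L-BFGS-B", options={"gtol": 1e-8})
    return final.x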

mrocklin commented 7 years ago

It looks like exp appears in the following two places:

def sigmoid(x):
    return 1 / (1 + exp(-x))

def logistic_loss(beta, X, y):
    '''Logistic Loss, evaluated point-wise.'''
    beta, y = beta.ravel(), y.ravel()
    Xbeta = X.dot(beta)
    eXbeta = np.exp(Xbeta)
    return np.sum(np.log1p(eXbeta)) - np.dot(y, Xbeta)

In the sigmoid case it seems like we only care about accuracy close to 0 because anything farther away will get squashed.

In the second case, log(1 + exp(x)), we also seem to only get interesting behavior around 0. For large positive x it quickly approaches y = x, and for large negative x it approaches y = 0.
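
As a quick illustration of this (not library code): log(1 + exp(x)) can be evaluated stably with np.logaddexp, and away from 0 it hugs the asymptotes max(x, 0):

import numpy as np

x = np.array([-40.0, -5.0, 0.0, 5.0, 40.0])
softplus = np.logaddexp(0.0, x)   # log(1 + exp(x)), computed stably
print(softplus)                   # approx [4e-18, 0.0067, 0.693, 5.0067, 40.0]
print(np.maximum(x, 0.0))         # the cheap asymptote y = max(x, 0)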

mrocklin commented 7 years ago

It's also worth noting that the example above is an extreme case where there are only two columns. As the number of columns increases we expect the matvec to dominate over the exp.

mcg1969 commented 7 years ago

There are plenty of pieces of advice that Stuart has offered above that would require no numerical compromise and would be worth considering. But for the rest I'm going to maintain my spoilsport position here.

I'm opposed to doing anything that returns an approximate result. Statisticians want what they want; heck, most modelers want what they want. And if that's a logistic regression, that means the global minimum of a particular cost function.

I also feel that an attempt to design a hybrid approach that attempts to preserve global convergence while exploiting approximations in early stages introduces a lot of distracting research and speculation.

As far as I know (and I could be wrong) a dependence on approximate exponentials is unprecedented in this area. That's true of every innovation, certainly. But I think it is worth asking: are any other well-respected or widely used GLM implementations approximating the exponential? And if not, why not? I took a look at H2O and didn't see any evidence they used anything but standard library exp. (That said, if we prove competitive or better with the exact model, and we can demonstrate significant improvements on top of that with an approximation, that would be attractive.)

Obviously, I'm all talk and no action right now (though I was actively working on test generation today). I'm not trying to exercise power over the project I don't have. I just don't think the market is asking for approximations for the sake of speed, when there ought to be myriad ways to improve upon the current best workflow practices, of which the algorithm is just a part.

mrocklin commented 7 years ago

My understanding is that choices like sigmoid were fairly qualitative to start with, and so outputs should be relatively robust to approximation here. Also, approximations of exp can be very very accurate if constrained to certain ranges.

There is evidence that people do approximate exp in deep learning communities.

However, benchmarking shows that this is only valuable in the small-number-of-columns case. If common use cases are well over 10 columns then the matvecs take over in cost.

mcg1969 commented 7 years ago

> My understanding is that choices like sigmoid were fairly qualitative to start with, and so outputs should be relatively robust to approximation here.

> There is evidence that people do approximate exp in deep learning communities.

That's absolutely right! And yet I consider it completely irrelevant to this discussion. So it seems our disagreement here may be more philosophical than technical.

There is no doubt that a larger modeling effort might be robust to an approximation. The very act of selecting a model to apply to a set of data implies approximation and heuristic. How many people have chosen least squares simply because it's something they know how to compute, without any consideration as to whether or not the error distribution is approximately IID Gaussian? And with a binary model, there are several common alternatives to the logistic model. When to choose one over the other is not always going to be clear, and many will just settle for what they know.

But once I've settled on a model, I want what I want. If I've decided that's going to be a logistic model, I'm going to be suspicious of a package that insists on offering me an approximation. Or perhaps it's not so much that we would insist on it, but offering it prominently might communicate that an approximation is necessary to achieve decent performance, when no other package I'm aware of feels the need to make that claim.

mrocklin commented 7 years ago

I hear what you're saying, and I appreciate the aversion to approximations. However, I think that this may be approximate in a similar way to how floating-point arithmetic is approximate. I suggest that we just try things out on real data and see if it has value in practice. I'm not suggesting that we go down this path, I'm suggesting that we explore it. I think it's cheap to try. I'm entirely happy to back off from it if it proves unfruitful or significantly affects accuracy.

I suspect that with a couple hours of work we could create an exp_approx implementation in numba that was arbitrarily accurate around the range relevant for sigmoid. We would then have the option to play with this function while we profile and benchmark different problems and see if it has positive or negative effects.
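
As a rough sketch of what such an exp_approx might look like (illustrative only, not anything in dask-glm; the +/-30 clamp and the polynomial degree are arbitrary choices): range-reduce so that exp(x) = 2**k * exp(r) with |r| <= ln(2)/2, then evaluate a short Taylor polynomial for exp(r).

import math
import numpy as np
import numba

LN2 = 0.6931471805599453

@numba.njit(cache=True)
def exp_approx(x):
    # x: 1-D float64 array; returns an elementwise approximation of exp(x).
    out = np.empty_like(x)
    for i in range(x.shape[0]):
        xi = x[i]
        # Clamp: by +/-30 the sigmoid has already saturated to ~0 or ~1.
        if xi > 30.0:
            xi = 30.0
        elif xi < -30.0:
            xi = -30.0
        # Range reduction: exp(xi) = 2**k * exp(r) with |r| <= ln(2)/2.
        k = int(math.floor(xi / LN2 + 0.5))
        r = xi - k * LN2
        # Short Taylor polynomial for exp(r); relative error ~1e-8 on this range.
        p = 1.0 + r * (1.0 + r * (1.0 / 2 + r * (1.0 / 6 + r * (1.0 / 24
            + r * (1.0 / 120 + r * (1.0 / 720 + r * (1.0 / 5040)))))))
        out[i] = math.ldexp(p, k)
    return out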

mcg1969 commented 7 years ago

Get me to within sqrt(eps) of the optimum (for the common measures thereof) and I won't complain.

hussainsultan commented 7 years ago

For ADMM, my understanding is that you don't have to be precise in the local updates for the solution to converge, which leads me to believe that some numeric testing will be helpful. @moody-marlin thoughts?

mcg1969 commented 7 years ago

There's quite a bit of literature about inexact search methods. For instance, in a typical gradient/Newton/BFGS descent method, you'll still see progress towards the optimal solution as long as your descent directions have a negative inner product with the gradient (and some other technical conditions hold). With large-scale Newton methods in particular you can sometimes get significant savings by solving the Newton system iteratively and inexactly.

But approximate search directions need to be distinguished from approximations of the models themselves.

stuartarchibald commented 7 years ago

This is an interesting discussion. I can't really comment on the domain-specific use case of what e.g. modellers want; however, as an engineer I have some thoughts on the specific issue of exp() speed. First, I think performance is going to depend to some extent on the level of generalisation in the exp() function and how much that function can make use of the hardware vector units. The exp() in any libm is very general: it has to operate over the whole numerical range and handle corner cases like NaN, Inf and denormals, and it only ever operates in a scalar context. Merely by handling the edge cases, before even considering how to actually compute values for exp(), a scalar implementation will inherently contain numerous branches and will naturally use scalar instructions.

Given exp() is required in this case to operate over a large amount of data, use of vector instructions, perhaps mixed with some data-specific specialisation, may help. In approximate order of effort and specialisation required...

mrocklin commented 7 years ago

@stuartarchibald which of these techniques would work well in a Python + Numba project that is strictly open source (no proprietary code)?

inati commented 7 years ago

Sigmoid is very smooth, have you considered a lookup table?

inati commented 7 years ago

Sigmoid has a neat Taylor series expansion, super fast to evaluate and reasonably accurate even with a pretty rough lookup table.

My interpolation is pretty lame, but it's just to show that one of you actual applied-math types could come up with something much more accurate.

I'm pretty sure that any reasonable rational function of exponents of polynomials will behave in a similar way.

How much precision do you need for function evals in GLM? 3 digits? or more?
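
For concreteness, here is a rough lookup-table-with-interpolation sketch along these lines (my own illustration, not inati's code; the [-8, 8] range and 2048-entry table are arbitrary choices):

import numpy as np

LO, HI, NPTS = -8.0, 8.0, 2048
_grid = np.linspace(LO, HI, NPTS)
_table = 1.0 / (1.0 + np.exp(-_grid))     # exact sigmoid at the table points
_step = (HI - LO) / (NPTS - 1)

def sigmoid_lut(x):
    # Approximate sigmoid via table lookup plus linear interpolation.
    x = np.clip(x, LO, HI)
    idx = np.minimum(((x - LO) / _step).astype(np.intp), NPTS - 2)
    frac = (x - LO) / _step - idx
    return _table[idx] * (1.0 - frac) + _table[idx + 1] * frac

x = np.linspace(-12.0, 12.0, 10001)
exact = 1.0 / (1.0 + np.exp(-x))
print(np.max(np.abs(sigmoid_lut(x) - exact)))   # ~3e-4, dominated by the clamp at +/-8

So with these (arbitrary) settings you get roughly three digits; a finer table, a wider range, or higher-order interpolation buys more.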

stuartarchibald commented 7 years ago

@mrocklin Enumerating the above techniques as 1-5.

  1. Proprietary so no.
  2. Proprietary so no.
  3. This is doable under your constraints and is not a compromise of precision, just a performance hack. Like (in vaguely C):

     for (size_t i = 0; i < n; i++)
         exp_x[i] = exp(x[i]);

     goes to

     void exp_func(double *input, double *output, size_t n)
     {
         for (size_t i = 0; i < n; i++)
         {
             // some inlined exp alg that's undergone the unswitch
             // possibly add in tiling here too, will need vector instruction support
         }
     }

     exp_func(x, exp_x, n);


  4. Is a "compromise" but only in that the necessary approximating polynomial is now approximating the sigmoid as opposed to the `exp()` function. Whilst not a general solution, as other functions are likely to rely on `exp`, as noted previously the sigmoid is smooth, and also satisfies sigmoid(-x) = 1 - sigmoid(x), and so is likely to be amenable to such techniques; see the sketch after this list.
  5. Language independent.
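
Purely as an illustration of option 4 (my own sketch, not Stuart's code; the [-8, 8] window and the polynomial degree are arbitrary), a polynomial can be fitted to the sigmoid directly with numpy's Chebyshev tools:

import numpy as np
from numpy.polynomial import Chebyshev

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Least-squares Chebyshev fit of the sigmoid itself on [-8, 8]; outside that
# window the input is clamped, where the sigmoid is already within ~3e-4 of
# its asymptotes 0 and 1.
xs = np.linspace(-8.0, 8.0, 4001)
cheb = Chebyshev.fit(xs, sigmoid(xs), deg=29)

def sigmoid_poly(x):
    return cheb(np.clip(x, -8.0, 8.0))

# Report the worst-case error of the fit over the sample grid.
print(np.max(np.abs(sigmoid_poly(xs) - sigmoid(xs))))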

thrasibule commented 7 years ago

One MIT-licensed implementation of a really fast exponential is here: https://github.com/jhjourdan/SIMD-math-prims. It looks like you get a 20x speedup vs expf.