data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/

RFC: special function extension #725

Open mdhaber opened 6 months ago

mdhaber commented 6 months ago

This RFC proposes adding a special function extension to the array API specification.

Overview

Several array libraries have some support for "special" functions (e.g. gamma), that is, mathematical functions that are broadly applicable but not considered to be "elementary" (e.g. sin). We[^1] propose adding a special sub-namespace to the array API specification, which would contain a number of special functions that are already implemented by many array libraries.

Prior Art

We begin with 25 particularly important special functions that are either already available for NumPy, PyTorch, CuPy, and JAX arrays or are easily implemented. Partial information about their signatures in these libraries is included in the table below; parameters that are less commonly supported/used are omitted.

| Function | numpy (scipy.special) | torch (torch.special) | cupy (cupyx.scipy.special) | jax (jax.scipy.special) |
| --- | --- | --- | --- | --- |
| log-sum-exp | logsumexp(z1, z2) | logsumexp(x) | logsumexp(x1, x2) | logsumexp(z1, z2) |
| logit | logit(x) | logit(x) | logit(x) | logit(z) |
| expit | expit(x) | expit(x) | expit(x) | expit(z) |
| log of normal CDF | log_ndtr(z) | log_ndtr(x) | log_ndtr(x) | log_ndtr(x) |
| normal CDF | ndtr(z) | ndtr(x) | ndtr(x) | ndtr(x) |
| normal CDF inverse | ndtri(x) | ndtri(x) | ndtri(x) | ndtri(x) |
| digamma | digamma(z) | digamma(x) | digamma(x) | digamma(x) |
| polygamma | polygamma(n, x) | polygamma(n, x) | polygamma(n, x) | polygamma(n, x) |
| multigammaln | multigammaln(x, n) | multigammaln(x, n) | multigammaln(x, n) | multigammaln(x, n) |
| log of absolute value of gamma | gammaln(x) | gammaln(x) | gammaln(x) | gammaln(x) |
| gamma | gamma(z) | - | gamma(x) | gamma(x) |
| gamma (incomplete lower, regularized) | gammainc(x1, x2) | gammainc(x1, x2) | gammainc(x1, x2) | gammainc(x1, x2) |
| gamma (incomplete upper, regularized) | gammaincc(x1, x2) | gammaincc(x1, x2) | gammaincc(x1, x2) | gammaincc(x1, x2) |
| log of absolute value of beta | betaln(x1, x2) | - | betaln(x1, x2) | betaln(x1, x2) |
| beta | beta(x1, x2) | - | beta(x1, x2) | beta(x1, x2) |
| beta (incomplete lower, regularized) | betainc(x1, x2, x3) | - | betainc(x1, x2, x3) | betainc(x1, x2, x3) |
| erf | erf(z) | erf(x) | erf(x) | erf(x) |
| erf complement | erfc(z) | erfc(x) | erfc(x) | erfc(x) |
| erf inverse | erfinv(x) | erfinv(x) | erfinv(x) | erfinv(x) |
| zeta | zeta(x1, x2) | zeta(x1, x2) | zeta(x1, x2) | zeta(x1, x2) |
| binomial coefficient | binom(x1, x2) | - | - | - |
| exponential integral | expi(x) | - | expi(x) | expi(x) |
| generalized exponential integral | expn(n, x) | - | expn(n, x) | expn(n, x) |
| softmax | softmax(z) | softmax(x) | softmax(x) | nn.softmax(z) |
| log of softmax | log_softmax(z) | log_softmax(x) | log_softmax(x) | nn.log_softmax(z) |

With the exception of the log-sum-exp function, which reduces along an axis, all work elementwise, producing an output whose shape is the broadcasted shape of the arguments. The variable names shown are not necessarily those used by the referenced library; instead they are standardized, with x/z/n denoting an argument of real/complex/integer dtype.

Further information about these functions in other languages (C++, Julia, Mathematica, Matlab, and R) is available in this spreadsheet.

Proposal

The Array API specification would include the following functions in a special sub-namespace.

| Function | Array API (proposed) | Name/interface change notes |
| --- | --- | --- |
| log-sum-exp | log_sum_exp(z, /, *, axis=-1, weights=None) | Enforce naming consistency |
| logit | logit(x, /) | Unchanged (standard) |
| expit | expit(x, /) | Unchanged (standard) |
| log of normal CDF | log_normcdf(a, b=None, /) | Existing name is cryptic |
| normal CDF | normcdf(a, b=None, /) | Existing name is cryptic |
| normal CDF inverse | normcdf_inv(p, /, *, a=None, b=None) | Existing name is cryptic |
| digamma | digamma(z, /) | Unchanged (standard) |
| polygamma | polygamma(n, x) | Unchanged (standard) |
| multigammaln | log_multigamma(x, n) | Enforce naming consistency |
| log of absolute value of gamma | log_abs_gamma(z, /, *, a=None, b=None, regularized=None) | Existing name imprecise |
| gamma | gamma(z, /, *, a=None, b=None, regularized=None) | Interface generalized |
| log of absolute value of beta | log_abs_beta(x1, x2, /, *, a=None, b=None) | Existing name imprecise |
| beta | beta(x1, x2, /, *, a=None, b=None) | Interface generalized |
| erf | erf(a, b=None, /) | Interface generalized |
| erf inverse | erf_inv(p, /, *, a=None, b=None) | Interface generalized |
| zeta | zeta(x1, x2=None, /) | Unchanged (standard) |
| binomial coefficient | binom(x1, x2, /) | New to most libraries |
| exponential integral | expinti(x, /) | Existing name is cryptic |
| generalized exponential integral | expintv(n, x) | Existing name is cryptic |
| softmax | softmax(z, /) | Unchanged (standard) |
| log of softmax | log_softmax(z, /) | Unchanged (standard) |
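
To make the a/b convention in the table concrete, here is a minimal scalar sketch of normcdf(a, b=None, /) built on Python's math module. It assumes, by analogy with range, that a single argument is read as the upper limit and that two arguments are the lower and upper limits. That reading, the scalar (non-array) inputs, and the naive two-argument formula are all assumptions of this sketch, not part of the proposal.

```python
import math

_SQRT2 = math.sqrt(2.0)


def normcdf(a, b=None, /):
    """Standard normal probability on (-inf, a], or on [a, b] if b is given.

    Hypothetical scalar stand-in for the proposed array function.
    """
    if b is None:
        # One argument: treat it as the upper limit, like range(stop).
        return 0.5 * math.erfc(-a / _SQRT2)
    # Two arguments: lower and upper limits, like range(start, stop).
    # Naive difference; a real implementation would need to guard against
    # cancellation when both limits lie far out in the same tail.
    return 0.5 * (math.erf(b / _SQRT2) - math.erf(a / _SQRT2))


print(normcdf(0.0))        # probability below 0
print(normcdf(-1.0, 1.0))  # probability within one standard deviation
```

Because of the trailing /, a call such as normcdf(a=0.0) raises TypeError; only positional use is accepted.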

A few notes about the interface:

Where applicable, we find that these conventions generalize well to other special functions that might be added in the future.

Other notes about function selection:

Questions / Points of Discussion:

  1. Some array libraries implement special functions only for real arguments, but many applications require the extension of these functions to complex arguments. We invite discussion about how to approach this. Can the standard specify that complex arguments should be accepted (with corresponding keyword names involving z rather than x) even if some libraries are not compliant initially?
  2. For functions with both log_ and _inv components in the name, the order of operations is ambiguous. For example, would log_normcdf_inv (which would be useful in statistics) be the logarithm of the inverse of normcdf or the inverse of the logarithm of normcdf?
    • One proposal for resolving the ambiguity would be to use attributes rather than function names to organize logarithms and inverse functions. For example, instead of separate functions normcdf, normcdf_inv, log_normcdf, and log_normcdf_inv, normcdf would have attributes normcdf.log and normcdf.inv, and normcdf.log would have an attribute normcdf.log.inv.
    • Another possibility is for log_ and inv_ to both be prefixes. However, _inv typically appears as a suffix in existing special function names, perhaps because the superscript $-1$ that denotes inversion often appears after the function symbol, e.g. $f^{-1}(x)$.
    • The natural alternative is for both _log and _inv to be suffixes. However, log typically appears as a prefix in existing function names, perhaps because this is how the function appears when typeset mathematically, e.g. $\log(f(x))$.
  3. Consider the Python range function: it is natural for range(y) to denote a range with an upper limit of y and for range(x, y) to generate a range between x and y. However, if the arguments were allowed to be specified as keywords, it would be unclear how they should be named. The call range(y) suggests that the name of the first argument might be stop, but range(x, y) suggests that the name of the first argument should be start; assigning either name and allowing both positional and keyword specification leads to confusion. To avoid this ambiguity, range requires that the arguments be passed as positional-only. We run into a similar situation with our a and b arguments. After carefully considering many possibilities, we have suggested the following above:
    • Functions for which the first two arguments are a/b require that these arguments are positional-only.
    • Functions with other arguments before a/b require that these arguments are keyword-only.
  4. The downside of these conventions for a/b is that they are somewhat restrictive. Users cannot call normcdf(a=x, b=y) with keywords to be explicit, nor can they call gamma(z, x, y) without keywords to be concise. A compromise would be to accept separate positional-only and keyword-only versions of the same argument, and implement logic to resolve the intended use. While this is anticipated to allow for both natural and flexible use, it would be somewhat more cumbersome to document and implement.
  5. The default value of the regularized argument of gamma is challenging to choose.
    • The incomplete gamma function (e.g. gamma(z, upper=y)) will typically be regularized, suggesting that a regularized=True default is more appropriate for this use case.
    • The regularized complete gamma function (e.g. gamma(z)) is identically 1, suggesting that regularized=False is more appropriate for this use case.
    • We can get the desired behavior by default in both cases by using regularized=None. When gamma is used as the complete gamma function (without a/b), regularized would be set to False, and when gamma is used as the incomplete gamma function (with a/b), regularized would be set to True. However, this is more complex to document than choosing either True or False as the default.
  6. Many of the argument names - and especially whether they should be positional-only or not - are up for debate. For instance, the two arguments of binom are not interchangeable, suggesting that some users might prefer to pass arguments by keyword. On one hand, n and k would be reasonable names, since the binomial coefficient is often needed in situations that call for "n choose k". On the other hand, the names n and k are not entirely universal, and the function is extended for real arguments, whereas names n and k are suggestive of integer dtypes. Also, while a and b are concise names that are commonly used for lower and upper limits of integration, they are not as descriptive as lower/upper, and might be confused with the symbols commonly used for different arguments of the same function (e.g. beta). low/high, lo/hi, ll/ul, c/d have also been proposed.
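
The regularized=None default described in point 5 can be sketched as follows. This is only an illustration under several assumptions: scalar real z > 0, limits a/b that are keyword-only as in the table above, a hypothetical helper _lower_gamma based on the power series of the lower incomplete gamma function, and evaluation by subtraction (which a production implementation would presumably avoid).

```python
import math


def _lower_gamma(z, x):
    """Unregularized lower incomplete gamma for real z > 0, x >= 0 (power series)."""
    if x == 0.0:
        return 0.0
    if math.isinf(x):
        return math.gamma(z)
    term = total = 1.0 / z
    k = 0
    while abs(term) > 1e-16 * abs(total):
        k += 1
        term *= x / (z + k)
        total += term
    return total * math.exp(z * math.log(x) - x)


def gamma(z, /, *, a=None, b=None, regularized=None):
    """Integral of t**(z-1) * exp(-t) from a to b (defaults: 0 and infinity)."""
    incomplete = a is not None or b is not None
    if regularized is None:
        # Complete gamma: unregularized. Incomplete gamma: regularized.
        regularized = incomplete
    lower = 0.0 if a is None else a
    upper = math.inf if b is None else b
    value = _lower_gamma(z, upper) - _lower_gamma(z, lower)
    return value / math.gamma(z) if regularized else value


print(gamma(5.0))                    # complete gamma function, 4!
print(gamma(1.0, b=2.0))             # regularized by default
print(gamma(3.0, regularized=True))  # regularized complete gamma
```

With this resolution rule, gamma(z) keeps its usual meaning while gamma(z, b=y) returns the regularized value without the caller having to ask for it.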

[^1]: @steppi, @izaid, @mdhaber, @rgommers

izaid commented 6 months ago

Really happy to see the discussion get started. One point I'll add is that, of course, there are many other special functions that are widely used. Indeed, many of the ones important to me (like the Bessel functions) are not covered by the above list. Why? We started with a minimal set of functions that are easily implementable everywhere. It's not super helpful to propose a standard for a special function that other array libraries will not implement because it's too much effort.

This is why the task of converting SciPy's internal special function implementations into C++, see https://github.com/scipy/scipy/issues/19404, is relevant and important.

steppi commented 6 months ago

Thanks @mdhaber! Very nice write-up. Looking forward to the discussion.

rgommers commented 5 months ago

Thanks for all the hard work on this @mdhaber, @izaid and @steppi! I'll add a few initial thoughts:

  1. The scope of this proposal is large. It seems like the first thing to decide here is whether we're happy to add this extension, whether we want to indeed go for the consistent naming scheme, and what to do with complex dtype support.
  2. Re argument names and positional/keyword-only: the proposal is missing a bunch of positional-only / symbols, I assume by accident. All one-letter x/y/z/n should be positional-only. Other than that it seems fairly straightforward, aside from keyword-only a/b, that's too non-descriptive I think.
  3. There is an issue with the axis keywords in the reduction-like functions (logsumexp, softmax, log_softmax); PyTorch requires the user to specify it, while JAX isn't consistent with SciPy/CuPy (-1 vs. None). That will be a problem, since the semantics are going to be different between them in a way that's not easily resolvable with a deprecation.
  4. My impression on function names:
    • ndtr renaming is 👍🏼, that name is too awful to standardize.
    • I'm less sure about expinti/expintv, those are the only ones that are pretty unreadable imho.
    • gamma: collapsing 3 functions into one seems to mean that the default gamma(x) now needs to specify integration bounds? It's not entirely clear - the long explanation in point (5) in your write-up suggests that there may not be much gained from this, compared to staying with existing APIs.
    • logsumexp is pretty heavily used, and log_sum_exp is following a consistent naming scheme but probably not actually more readable (functions with 2 underscores tend to be slightly awkward - goes for log_abs_xxx too).
    • Overall, it'd be useful to categorize all the renames with reasons for them, e.g.:
      • "must do because introduction with existing names isn't possible due to semantic differences",
      • "should do because name is terrible",
      • "should do to generalize the function",
      • "should do for consistency in chosen naming scheme".

> The binomial coefficient function (binom) does not seem to be implemented for PyTorch, CuPy, or JAX arrays, but the need is so fundamental that we wish to include it in the standard.

Such "not implemented" status has typically been a blocker for inclusion. For the linalg extension we also discussed a preliminary list, something like "if a library adds this, it must be with this signature and semantics". I'm not sure how fundamental binom actually is for real-world applications; the feature request for PyTorch was approved 3 years ago for example (https://github.com/pytorch/pytorch/issues/47841), but no one even commented since.

mdhaber commented 3 months ago

Following the numbers used in https://github.com/data-apis/array-api/issues/725#issuecomment-1881598250:

  1. I suppose I'll weigh in just to kick things off:
    • Yes, I think the extension should be added : )
    • My first inclination was to adopt the same names/interfaces as the major array libraries where there is already agreement. After much careful consideration with others, I think it's worth the effort to make the changes.
    • For complex dtype support, I think it's OK for the standard to specify that certain parameters must support complex input even if some implementations are going to be non-compliant initially. (Few array libraries were fully array API compliant initially, right?) There is a lot of recent work in SciPy toward sharing special function implementations across backends, which will help all libraries to become compliant[^1].
  2. I am not sure that the positional-only symbols were missing by accident, but I've added a few more. It is intentional not to make the arguments of polygamma, log_multigamma, and expintv positional-only.
  3. Good point, but I think we can resolve this in a way that is mostly backward compatible even without requiring deprecations. IIUC, the standard would have all array libraries expose these functions in a namespace <library>.special. All libraries that we've studied except for PyTorch will need to add this namespace, so they get to start the interface from scratch. PyTorch would need to add an argument named axis to their existing functions (since they currently use dim, which has no default). Even if we choose to add axis with a default, this would not necessarily break existing user code which already specifies dim (which could take precedence over axis). Does that work?
  4. Re: names
    • expinti/expintv are not ideal names, but they are a little more explicit than expi and expn. Mathematically, these functions are represented by $\mathrm{Ei}$ and $E_n$, so actually, I'm forgetting where expintv came from (@steppi @izaid?). One idea inspired by Mathematica and R is to call them expint_ei and expint_en. Would that be better?
    • Re: gamma collapsing several functions into one - no, gamma(z) would compute the good old gamma function as usual. I don't think the long explanation in my write-up suggests that there is not much to be gained; rather, it suggests that there are still decisions to be made. There is much to be gained, especially the eventual ability to evaluate the gamma function integral between arbitrary lower and upper limits of integration rather than relying on subtraction, which is less readable and can cause catastrophic cancellation.
    • I can live with logsumexp from the perspective that it is no longer used as a description of what the function does (in which case it would follow the convention) but it is the name of the function. logsumexp has become the name throughout the scientific Python ecosystem and even in other programming languages. (@steppi @izaid thoughts?)
    • I can work with @steppi and @izaid to add information about the rationale for the names. Rationale added to table above.
  5. Re: "Such 'not implemented' status has typically been a blocker for inclusion". Maybe we can submit PRs to implement them? Or (preferably), maybe that precedent can be changed? (I'm sure that others have made arguments for the latter before, but LMK if someone needs to make a case for a standard being prescriptive rather than solely descriptive.) As for whether binom in particular is important for real-world applications, comb/binom is used in scipy.interpolate, scipy.linalg, scipy.signal, and scipy.stats, for example. I am willing to admit that scipy.stats.ks2samp (in which binom is used) is not useful for real-world applications[^2], but the rest seem useful to me.
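
The cancellation concern in point 4 can be illustrated with a function that the Python standard library already provides. The sketch below uses the normal distribution as a stand-in for the gamma integral (an assumption made purely so that the example runs with math.erf and math.erfc); it compares the probability between two limits computed by subtracting CDF values with the same quantity computed from the complementary tails.

```python
import math

SQRT2 = math.sqrt(2.0)
a, b = 10.0, 11.0  # both limits far out in the upper tail


def cdf(x):
    """Standard normal CDF."""
    return 0.5 * (1.0 + math.erf(x / SQRT2))


def sf(x):
    """Standard normal survival function, 1 - CDF, computed without subtraction."""
    return 0.5 * math.erfc(x / SQRT2)


# Each CDF value rounds to exactly 1.0 in double precision, so the
# difference carries no information.
naive = cdf(b) - cdf(a)

# The two tails are tiny but representable, so their difference is accurate.
stable = sf(a) - sf(b)

print(naive)
print(stable)  # on the order of 7.6e-24
```

A function that accepts both limits at once can choose the accurate formulation internally, which is not possible when the caller performs the subtraction.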

[^1]: For arguments that are allowed to be specified by keyword, we probably need to double-check that the names (n vs x vs z) are appropriate for the dtype we intend compliant implementations to accept. For positional-only arguments, libraries can name them in their documentation according to the input type they accept.

[^2]: Or it probably shouldn't be used given the alternatives available.

izaid commented 3 months ago

@mdhaber covered it all super well, but I'll chip in just a little.

  1. Yes, I also think the extension should be added. On the names, I think it's important to get this right and modernise names that no longer make sense (looking at you, ndtr). For complex dtypes, I don't think we should leave this out.

  2. For expinti and expintv, I think we were a little stuck. The current names are expi and expn, and expi especially feels like it should be exp(i x), not some sort of integral. The issue is there is not a good abbreviation for "integral". If we use "it", we end up with expit, which is already a function (logistic sigmoid). If we use "int", it sounds like something related to "integer". So, for these two, I think we don't know what to do and are open to suggestions. The "v" in expintv is much more explainable: we generalised it to support floats and not just integers, and "v" is a usual marker for that.

As for logsumexp versus log_sum_exp, actually I'm keen to be consistent with the convention. In that particular case, I could live happily, but I would prefer to keep things similar. As for log_abs_xxx, consider the current situation in SciPy where loggamma is the logarithm of the gamma function and gammaln is the logarithm of the absolute value of the gamma function. That really should be changed, it's confusing.
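
The distinction between the two logarithms is easy to see at a point where the gamma function is negative. The sketch below uses only the standard library: math.lgamma returns the logarithm of the absolute value, while the logarithm of the gamma function itself is complex there. The principal branch of cmath.log is used purely for illustration and is an assumption of this sketch, not necessarily the branch that any array library returns.

```python
import cmath
import math

x = -0.5
g = math.gamma(x)             # Gamma(-0.5) = -2*sqrt(pi), a negative number

log_abs = math.lgamma(x)      # log(|Gamma(x)|): always real
log_gamma = cmath.log(g)      # principal log(Gamma(x)): complex because g < 0

print(g)
print(log_abs)
print(log_gamma)              # same real part, imaginary part pi
```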

Generally very happy to discuss! Think it's important to get this right.

oleksandr-pavlyk commented 3 months ago

I would stay close to the Digital Library of Mathematical Functions, https://dlmf.nist.gov/6.2, and perhaps name expinti explicitly exp_integral_ei, as in the Wolfram Language: https://reference.wolfram.com/language/ref/ExpIntegralEi.html

Verbosity is not an issue nowadays with Copilot and IDEs.

NeilGirdhar commented 2 months ago

Just curious, but would it make sense to copy or move logaddexp, expm1, and log1p to the special function extension? While current numpy users might expect these in the main namespace, new users of numpy would probably find it quirky that these special functions have been "promoted" to the main namespace.

asmeurer commented 2 months ago

We shouldn't move them. That would be a compatibility break with existing versions of the standard. It wouldn't be a big deal to duplicate them. There's a similar thing for some functions like matmul in the linalg extension.

NeilGirdhar commented 2 months ago

It might also make sense to ask whether there might eventually be a "neural network" extension that reflects the functions in jax.nn and torch.nn.functional. Would any of the above functions be better suited in such an extension? (I personally don't think so.)

kgryte commented 2 months ago

@NeilGirdhar Re: neural network extension. See https://github.com/data-apis/array-api/issues/158, which you previously commented on.

fancidev commented 2 months ago

Great additions! My two cents about some of the points:

+1 on supporting both lower and upper bounds of integration. The parameter convention of Python’s range function feels not the most straightforward, but I don’t have a better alternative either.

Are normcdf and erf linear transforms of each other? If so, should we keep just one of them to keep the interface lean?
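
For reference, the relation in question is $\Phi(x) = \tfrac{1}{2}\left(1 + \operatorname{erf}(x/\sqrt{2})\right)$, so the two functions are affine transforms of each other once the argument is rescaled. A quick scalar check with the standard library (statistics.NormalDist stands in for the proposed normcdf, which is an assumption of this sketch):

```python
import math
from statistics import NormalDist


def normcdf_via_erf(x):
    """Standard normal CDF written in terms of erf."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


reference = NormalDist()  # mean 0, standard deviation 1
for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    print(x, normcdf_via_erf(x), reference.cdf(x))
```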

Where some implementation does not yet support complex argument of a function, does it make sense to standardize real argument first so that all implementations become compliant? Each implementation is then free to support complex argument.

Regarding the default value of axis, what is the default of those in the Array API? Using the same seems natural. If some implementation has a different default, their users just have to reckon that it is different in the Array API.

asmeurer commented 2 months ago

> Regarding the default value of axis, what is the default of those in the Array API? Using the same seems natural. If some implementation has a different default, their users just have to reckon that it is different in the Array API.

Functions like sum default to axis=None, which means to reduce over the whole array.
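
The practical difference between the two defaults discussed in this thread can be sketched with a pure-Python log-sum-exp over a nested list. The helper names and the restriction to axis=None and axis=-1 on 2-D input are assumptions made for illustration; this is not any library's implementation.

```python
import math


def log_sum_exp_1d(values):
    """Numerically stable log(sum(exp(v))) over a flat sequence."""
    m = max(values)
    if math.isinf(m):
        return m
    # Subtracting the maximum keeps exp() from overflowing.
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_sum_exp(rows, axis=None):
    """Hypothetical helper for a 2-D nested list; only axis=None and axis=-1."""
    if axis is None:
        # Reduce over every element, as sum does by default in the standard.
        return log_sum_exp_1d([v for row in rows for v in row])
    if axis == -1:
        # Reduce each row separately, as in the proposed axis=-1 default.
        return [log_sum_exp_1d(row) for row in rows]
    raise ValueError("this sketch supports only axis=None or axis=-1")


x = [[0.0, 0.0], [1000.0, 1000.0]]
print(log_sum_exp(x, axis=None))  # a single number
print(log_sum_exp(x, axis=-1))    # one number per row
```

The same call without an explicit axis therefore returns results of different shapes depending on which default a library chooses, which is why the choice cannot be changed silently later.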