pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Function request: logerfc, logerfcx special functions #31945

Open cossio opened 4 years ago

cossio commented 4 years ago

🚀 Feature

Implement the erfcx(x) special function, which computes exp(x^2) * erfc(x) in a numerically stable way. Also, for convenience, add logerfc(x) = log(erfc(x)) and logerfcx(x) = log(erfcx(x)).

erfcx is available in many numerical packages, such as Matlab, Julia, SciPy, R, and others.

From erfcx it is easy to obtain logerfc and logerfcx, but this involves a conditional which can be slow in pure Python code. So I recommend adding logerfc and logerfcx as well, which can be implemented as:

from math import log
from scipy.special import erfc, erfcx

def logerfc(x):
    if x > 0.0:
        # erfc(x) underflows for large x; compute in log space via erfcx
        return log(erfcx(x)) - x**2
    else:
        return log(erfc(x))

def logerfcx(x):
    if x < 0.0:
        # erfcx(x) overflows for very negative x; fall back to erfc
        return log(erfc(x)) + x**2
    else:
        return log(erfcx(x))
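
For example, here is the kind of underflow these definitions avoid (a quick check using SciPy, assuming it is installed):

import numpy as np
from scipy.special import erfc, erfcx

x = 30.0
print(erfc(x))                  # 0.0 -- erfc underflows in float64
print(np.log(erfcx(x)) - x**2)  # ~ -903.97, the finite value of log(erfc(30))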

Motivation

These special functions are very useful whenever we have to work with truncated normal distributions.
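
For example, the normalization constant of a normal distribution truncated to [a, inf) is 0.5 * erfc((a - mu) / (sigma * sqrt(2))), and its log underflows to -inf for large a unless logerfc is used. A scalar sketch, assuming the logerfc defined above is in scope (the helper name here is only for illustration):

from math import log, sqrt

def truncnorm_lognorm(a, mu, sigma):
    # log P(X >= a) for X ~ Normal(mu, sigma^2)
    return logerfc((a - mu) / (sigma * sqrt(2))) - log(2)

print(truncnorm_lognorm(40.0, 0.0, 1.0))  # ~ -804.61; log(erfc(40 / sqrt(2))) alone gives -inf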

Related: https://github.com/pytorch/pytorch/issues/2129, https://github.com/pytorch/pytorch/issues/32293

cc @mruberry @rgommers @heitorschueroff

zou3519 commented 4 years ago

Seems reasonable, especially if all those other libraries have it.

cossio commented 4 years ago

@zou3519 I think I could implement this taking inspiration from the algorithms used in the above libraries.

If I wanted to add this myself, do you have any suggestions on what I should do? Maybe there are some PRs I can look at where people implement special functions in PyTorch? Would it have to be in C++, or could it be pure Python code?

cossio commented 4 years ago

I did a Python-based implementation that calls into SciPy:

import numpy as np
import scipy.special
import torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # delegate to SciPy (CPU tensors only)
        input_np = input.detach().numpy()
        result_np = scipy.special.erfcx(input_np)
        result = input.new(result_np)
        ctx.save_for_backward(input, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        # d/dx erfcx(x) = 2*x*erfcx(x) - 2/sqrt(pi)
        g = -2 / np.sqrt(np.pi) + 2 * input * result
        return g * grad_output

erfcx = ErfcxFunction.apply

def logerfc(x):
    # note: torch.where evaluates both branches, so gradients can still be
    # nan at entries where erfc underflows
    return torch.where(x > 0, erfcx(x).log() - x**2, x.erfc().log())

def logerfcx(x):
    return torch.where(x < 0, x.erfc().log() + x**2, erfcx(x).log())
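
As a sanity check of the backward formula (assuming the definitions above; gradcheck needs double precision):

x = torch.linspace(-2.0, 2.0, 5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(erfcx, (x,)))  # True if analytic and numeric grads agree
print(logerfc(torch.tensor([5.0])))           # ~ -27.20, finite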

Any comments are welcome. With some guidance on the PyTorch source code, I can turn this into a PR.

zou3519 commented 4 years ago

@cossio,

I'm not sure what erfcx does under the hood, but PyTorch does not have a numpy/scipy dependency, so we cannot take a pull request if it is implemented as shown.

Given that erfcx is a pointwise function (it operates element-wise), we should put it here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/UnaryOps.cpp , in C++. It's possible to implement it in Python, but we prefer C++ implementations unless the implementation is trivial.

I'm happy to review any pull requests on this subject.

french-paragon commented 4 years ago

I wrote a dummy implementation of log(erfc(x)) in PyTorch, with polynomial approximations for large positive x. There's still a NumPy dependency, but only to get some math constants in my project. It should be trivial to implement in C++ and CUDA, as both already provide the erfc function. If it interests someone, I can try to implement a proper C++ version. I just don't know whether someone better at math than me would have reservations about my approximations.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 27 14:01:28 2020

@author: laurent
"""

import torch
from torch.autograd import Function

import numpy as np
from scipy.special import erfc

class logErfc(Function):
    """
    An implementation of log(erfc(x)) with polynomial approximations for x
    leading to infs or nans. This should work for a large range of float32
    values.
    """

    @staticmethod
    def forward(ctx, x):
        ret = torch.log(torch.erfc(x))
        # where the direct formula over/underflows, switch to the asymptotic
        # log(erfc(x)) ~ -x**2 + const, anchored to match exactly at x = 10
        mask = torch.isinf(ret) | torch.isnan(ret)
        ret[mask] = torch.tensor(np.log(erfc(10.)), dtype=x.dtype, device=x.device) - x[mask]**2 + 100

        ctx.save_for_backward(x)

        return ret

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors

        # d/dx log(erfc(x)) = -2/sqrt(pi) * exp(-x**2) / erfc(x)
        delta = -2. / np.sqrt(np.pi) * 1 / torch.erfc(x) * torch.exp(-x**2)
        # where erfc underflows this becomes inf * 0 = nan; fall back to the
        # leading-order asymptotic slope -2x
        mask = torch.isinf(delta) | torch.isnan(delta)
        delta[mask] = -2 * x[mask]

        return grad_output * delta
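
For illustration, a usage sketch under the same assumptions (the class above exposed through .apply):

logerfc = logErfc.apply

x = torch.tensor([-3.0, 0.0, 5.0, 30.0], requires_grad=True)
y = logerfc(x)
y.sum().backward()
print(y)       # finite everywhere, even at x = 30 where erfc(x) underflows
print(x.grad)  # ~ -2x in the masked large-x region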

cossio commented 3 years ago

@french-paragon It would be nice to have this!

mruberry commented 3 years ago

cc @kshitij12345, more requests to track in https://github.com/pytorch/pytorch/issues/50345.

fyi @cossio, we now have the torch.special namespace in nightlies and @kshitij12345 is adding ops to it for the 1.9 release. We won't cover every torch.special op in the 1.9 release (or maybe ever), but these requests are helpful in prioritizing which ops we do add.
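
Once erfcx lands in torch.special, logerfc composes directly from it. A sketch, assuming a build in which torch.special.erfcx is available:

import torch

def logerfc(x):
    # stable everywhere: erfcx branch for x > 0, plain erfc otherwise
    return torch.where(x > 0,
                       torch.special.erfcx(x).log() - x**2,
                       torch.special.erfc(x).log())

print(logerfc(torch.tensor([-5.0, 0.0, 30.0])))  # finite values everywhere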

cossio commented 3 years ago

Thanks @mruberry! Hope this one gets added.