jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License
693 stars, 39 forks

Have a decorator to wrap universal functions? #34

Open eserie opened 3 years ago

eserie commented 3 years ago

In order to simplify writing universal functions, it would be great to have a decorator that hides the technical part of the code (the conversion of the inputs and outputs of the wrapped function/method). For example, the code:

import eagerpy as ep

def my_universal_function(a, b, c):
    # Convert all inputs to EagerPy tensors
    a, b, c = ep.astensors(a, b, c)

    # performs some computations
    result = (a + b * c).square()

    # and return a native tensor
    return result.raw

would become:

@eager_function
def my_universal_function(a, b, c):
    return (a + b * c).square()
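For the simple case of flat positional arguments, the conversion boilerplate can be factored out with a plain `functools.wraps` decorator. The sketch below is only an illustration under stated assumptions: to keep it runnable without EagerPy installed, a hypothetical `Wrapped` class stands in for an EagerPy tensor (its `.raw` attribute holds the native value). With EagerPy, `Wrapped(x)` would roughly correspond to `ep.astensor(x)`.

```python
from functools import wraps
from typing import Any, Callable


class Wrapped:
    """Hypothetical stand-in for an EagerPy tensor; the native value lives in .raw."""

    def __init__(self, raw: Any) -> None:
        self.raw = raw

    def __add__(self, other: "Wrapped") -> "Wrapped":
        return Wrapped(self.raw + other.raw)

    def __mul__(self, other: "Wrapped") -> "Wrapped":
        return Wrapped(self.raw * other.raw)

    def square(self) -> "Wrapped":
        return Wrapped(self.raw * self.raw)


def eager_function(func: Callable[..., Wrapped]) -> Callable[..., Any]:
    """Sketch: wrap every positional argument, call func, unwrap the result."""

    @wraps(func)
    def wrapper(*args: Any) -> Any:
        wrapped_args = [Wrapped(a) for a in args]  # convert the inputs
        return func(*wrapped_args).raw  # convert the output back to native

    return wrapper


@eager_function
def my_universal_function(a, b, c):
    return (a + b * c).square()


print(my_universal_function(1, 2, 3))  # 49, i.e. (1 + 2 * 3) ** 2
```

This only covers flat positional arguments and always returns native values; keyword arguments, nested containers and the "EagerPy in, EagerPy out" behavior need the more general treatment discussed in this thread.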

In addition, we could add the feature that if the input tensors are already eagerpy tensors, then no conversion to raw format should be done on the output tensors.

I wrote a prototype of such a decorator function. It will not work with arbitrary types of arguments, so using it would require the wrapped function to have a rather "simple" signature (args and kwargs consisting of tensors or of nested containers with tensors on the leaves: dict, list, tuple or namedtuple-like containers).
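The "simple signature" restriction boils down to a recursive traversal that converts the leaves and rebuilds every container with its original type. A minimal pure-Python sketch of that traversal, using a hypothetical `map_leaves` helper; in a decorator, `fn` would be a conversion such as `ep.astensor` on the way in. Namedtuples need their fields passed positionally and defaultdicts need their `default_factory` carried over, which is why they get their own branches.

```python
from collections import defaultdict, namedtuple
from typing import Any, Callable


def map_leaves(fn: Callable[[Any], Any], data: Any) -> Any:
    """Apply fn to every leaf of a nested dict/list/tuple/namedtuple structure,
    rebuilding each container with its original type (hypothetical helper)."""
    if isinstance(data, defaultdict):
        # keep the default_factory of defaultdicts
        items = {k: map_leaves(fn, v) for k, v in data.items()}
        return type(data)(data.default_factory, items)
    if isinstance(data, dict):
        return type(data)({k: map_leaves(fn, v) for k, v in data.items()})
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: fields must be passed positionally
        return type(data)(*(map_leaves(fn, v) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(map_leaves(fn, v) for v in data)
    return fn(data)  # leaf


Point = namedtuple("Point", ["x", "y"])
data = {"p": Point(1, 2), "xs": [3, (4, 5)], "d": defaultdict(list, {"k": 6})}
out = map_leaves(lambda v: v * 10, data)
print(out)
```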

Would you consider having this feature in eagerpy?

jonasrauber commented 3 years ago

Would you consider having this feature in eagerpy?

Yes, a nice generic decorator that can handle an arbitrary number of arguments (and return values) would be great. I am pretty sure I thought about this before, but I cannot recall why I didn't do it.

In addition, we could add the feature that if the input tensors are already eagerpy tensors, then no conversion to raw format should be done on the output tensors.

Have you seen the ep.astensor_ and ep.astensors_ functions (with the underscore)? They already do exactly that: https://eagerpy.jonasrauber.de/guide/generic-functions.html (see the examples at the end).
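The idea behind the underscore variants, as described in the linked guide, is to remember whether the input was already wrapped and to return a function that restores the caller's type. A minimal sketch of that pattern follows. It uses a hypothetical `Wrapped` class in place of `ep.Tensor` so that it runs without EagerPy, and it illustrates the behavior rather than reproducing EagerPy's implementation.

```python
from typing import Any, Callable, Tuple


class Wrapped:
    """Hypothetical stand-in for an EagerPy tensor; the native value lives in .raw."""

    def __init__(self, raw: Any) -> None:
        self.raw = raw


def as_wrapped_(x: Any) -> Tuple[Wrapped, Callable[[Wrapped], Any]]:
    """Wrap x and return a function that restores the caller's tensor type."""
    was_wrapped = isinstance(x, Wrapped)
    wrapped = x if was_wrapped else Wrapped(x)

    def restore_type(result: Wrapped) -> Any:
        # Keep the wrapper if the caller passed a wrapped tensor,
        # otherwise hand back the native value.
        return result if was_wrapped else result.raw

    return wrapped, restore_type


def double(x: Any) -> Any:
    x, restore_type = as_wrapped_(x)
    result = Wrapped(x.raw * 2)
    return restore_type(result)


print(double(21))  # 42: native in, native out
print(type(double(Wrapped(21))).__name__)  # Wrapped: wrapped in, wrapped out
```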

eserie commented 3 years ago

Thanks for your response! No, unfortunately I hadn't seen the functions astensor_ and astensors_ (only astensor). They should definitely be a good starting point! I think I have a POC for a version of that function that could handle more general formats for inputs/outputs. I can try to integrate it into eagerpy and propose a PR in the coming days, if you are ok with that.

eserie commented 3 years ago

Below is a first POC that I wrote for the wrapper function eager_function (it works). At this stage, it is not entirely clear to me how to integrate it into the code base. I would appreciate some first feedback on this code to know whether we can go further in this direction.

import numbers
from collections import defaultdict
from functools import wraps
from typing import Any, Tuple

import numpy as np

import eagerpy as ep

def _tuple_as(template, data):
    """Create a list, tuple or namedtuple of the same type as template."""
    data = list(data)
    if hasattr(template, "_fields"):
        # named tuple case: fields are passed positionally
        return type(template)(*data)
    # list, tuple case
    return type(template)(data)

def _dict_as(template, data):
    """Create dictionary like data structure from template object.
    Parameters
    ----------
    template
        object used as template
    data
        data used to fill the created object.
    """
    if isinstance(template, defaultdict):
        return type(template)(template.default_factory, data)
    return type(template)(data)

def as_eager_tensors(data: Any) -> Any:
    return as_eager_tensors_(data)[0]

def as_eager_tensors_(data: Any) -> Tuple[Any, bool]:
    """Convert to eagerpy tensors.
    Parameters
    ----------
    data : (tuple, list, dict, namedtuple, defaultdict)
        data structure to convert

    Returns
    -------
    data
        the same data structure with native tensors converted to eagerpy tensors
    unwrap : bool
        if True, it means that at least one native tensor has been converted
        to an eagerpy tensor.

    """
    if isinstance(data, dict):
        # dict, defaultdict
        if not data:
            return data, False
        keys, res_values, unwrap_values = zip(
            *[(dim,) + as_eager_tensors_(var) for dim, var in data.items()]
        )
        unwrap = any(unwrap_values)
        return _dict_as(data, dict(zip(keys, res_values))), unwrap
    elif isinstance(data, (list, tuple)):
        # list, tuple, namedtuple
        if not data:
            return data, False
        res_values, unwrap_values = zip(*[as_eager_tensors_(var) for var in data])
        unwrap = any(unwrap_values)
        return _tuple_as(data, res_values), unwrap
    elif isinstance(data, ep.Tensor):
        return data, False
    elif isinstance(data, np.datetime64):
        # datetime not managed by ep.tensors
        return data, False
    elif isinstance(data, numbers.Number):
        return data, False
    return ep.astensor(data), True

def as_raw_tensors(data: Any) -> Any:
    """Convert from eager tensors to raw tensors.

    Parameters
    ----------
    data
        data to convert

    """
    if isinstance(data, dict):
        return _dict_as(data, {dim: as_raw_tensors(var) for dim, var in data.items()})
    elif isinstance(data, (list, tuple)):
        return _tuple_as(data, (as_raw_tensors(var) for var in data))
    elif isinstance(data, ep.Tensor):
        return data.raw
    return data

def restore_tensor_type(data: Any, unwrap: bool) -> Any:
    if unwrap:
        return as_raw_tensors(data)
    else:
        return data

def eager_function(func):
    # Heuristic: treat func as a method if its qualified name contains a
    # class name (e.g. "MyClass.method", but not "outer.<locals>.func").
    qualname_parts = func.__qualname__.split(".")
    is_method = len(qualname_parts) > 1 and qualname_parts[-2] != "<locals>"

    @wraps(func)
    def eager_func(*args, **kwargs):
        self = None
        if is_method:
            # do not convert `self`, only the remaining arguments
            self, args = args[0], args[1:]
        args, args_unwrap = as_eager_tensors_(args)
        kwargs, kwargs_unwrap = as_eager_tensors_(kwargs)
        # convert outputs back to raw tensors only if raw tensors were passed in
        unwrap = args_unwrap or kwargs_unwrap
        if is_method:
            args = (self,) + tuple(args)
        result = func(*args, **kwargs)
        return restore_tensor_type(result, unwrap)

    return eager_func

eserie commented 3 years ago

Another possibility would be to use the pytrees implemented in JAX. This would make it possible to handle more data structures and to rely on the existing astensors_ implementation by applying it to the flattened inputs and outputs. However, this would create a hard dependency on JAX in eagerpy, whereas currently it may be optional.
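The pytree approach works by flattening nested inputs into a list of leaves plus a way to rebuild the structure, converting the leaves in one batch (in eagerpy, something like `ep.astensors(*leaves)`), and rebuilding the structure afterwards. JAX provides this as `jax.tree_util.tree_flatten` / `tree_unflatten`. To avoid the hard JAX dependency, the same idea can be sketched in pure Python for dict, defaultdict, list, tuple and namedtuple containers; the helper names here are hypothetical and this is not JAX's API.

```python
from collections import defaultdict
from typing import Any, Callable, Iterator, List, Tuple


def _flatten(data: Any, leaves: List[Any]) -> Callable[[Iterator[Any]], Any]:
    """Append the leaves of data to `leaves`; return a function that rebuilds
    the structure by consuming new leaves from an iterator in the same order."""
    if isinstance(data, dict):
        keys = list(data)
        rebuilders = [_flatten(data[k], leaves) for k in keys]

        def rebuild_dict(it: Iterator[Any]) -> Any:
            items = {k: r(it) for k, r in zip(keys, rebuilders)}
            if isinstance(data, defaultdict):
                return type(data)(data.default_factory, items)
            return type(data)(items)

        return rebuild_dict
    if isinstance(data, (list, tuple)):
        rebuilders = [_flatten(v, leaves) for v in data]
        cls = type(data)
        if hasattr(data, "_fields"):  # namedtuple: positional fields
            return lambda it: cls(*[r(it) for r in rebuilders])
        return lambda it: cls([r(it) for r in rebuilders])
    leaves.append(data)  # leaf
    return lambda it: next(it)


def tree_flatten(data: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Return (leaves, unflatten) for a nested container (hypothetical helper)."""
    leaves: List[Any] = []
    rebuild = _flatten(data, leaves)
    return leaves, lambda new_leaves: rebuild(iter(new_leaves))


nested = {"a": [1, 2], "b": (3, {"c": 4})}
leaves, unflatten = tree_flatten(nested)
print(leaves)  # [1, 2, 3, 4]
print(unflatten([x * 10 for x in leaves]))  # {'a': [10, 20], 'b': (30, {'c': 40})}
```

Because all leaves pass through a single conversion call, the "was anything already wrapped?" decision can be made once for the whole structure instead of being threaded through the recursion.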

eserie commented 3 years ago

I propose an implementation based on pytrees in https://github.com/jonasrauber/eagerpy/pull/41. This approach implies a few changes, such as no longer registering JAXTensor as a pytree data structure and instead using JAX's pytree utilities for more general data structure manipulation in eagerpy. The newly introduced data structure conversion functions also make it possible to simplify the method JAXTensor._value_and_grad_fn somewhat (the method for which the initial registration of JAXTensor was tailored).

eserie commented 3 years ago

In fact, I think it's a bad idea not to register JAXTensor as a pytree: that would break compatibility with JAX functionality. I will restore the registration in an update of the pull request.