eserie opened this issue 3 years ago
Would you consider adding this feature to eagerpy?
Yes, a nice generic decorator that can handle an arbitrary number of arguments (and return values) would be great. I am pretty sure I have thought about this before, but I cannot recall why I didn't do it.
In addition, we could make it so that if the input tensors are already eagerpy tensors, the output tensors are not converted back to the raw format.
Have you seen the ep.astensor_ and ep.astensors_ functions (with the underscore)? They already do exactly that: https://eagerpy.jonasrauber.de/guide/generic-functions.html (see the examples at the end).
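As I read the linked guide, the pattern is roughly: ep.astensor_(x) returns the converted tensor together with a restore_type function, and the function body ends with restore_type(result), so raw inputs give raw outputs while eagerpy inputs stay eagerpy tensors. The sketch below reproduces that idea without importing eagerpy: Eager is a hypothetical stand-in for ep.Tensor and astensor_ is a toy reimplementation, not eagerpy's code. Letting the first argument alone decide the output type is an assumption of this sketch; check ep.astensors_ for its actual rule.

```python
from typing import Any, Callable, Tuple


class Eager:
    """Hypothetical stand-in for ep.Tensor: wraps a raw value."""

    def __init__(self, raw: Any) -> None:
        self.raw = raw

    def __add__(self, other: "Eager") -> "Eager":
        return Eager(self.raw + other.raw)


def astensor_(x: Any) -> Tuple[Eager, Callable[[Eager], Any]]:
    """Toy version: wrap x and return a function restoring the input's type."""
    if isinstance(x, Eager):
        # Already wrapped: leave the result wrapped.
        return x, lambda t: t
    # Raw input: unwrap the result on the way out.
    return Eager(x), lambda t: t.raw


def add(x: Any, y: Any) -> Any:
    # Assumption of this sketch: the first argument decides the output type.
    x, restore_type = astensor_(x)
    y, _ = astensor_(y)
    return restore_type(x + y)


print(add(1.0, 2.0))             # raw in, raw out: 3.0
print(add(Eager(1.0), 2.0).raw)  # Eager in, Eager out; .raw is 3.0
```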
Thanks for your response! No, unfortunately I hadn't seen the functions astensor_ and astensors_ (only astensor). They are definitely a good starting point! I think I have a POC for a version of that function that can handle more general input/output formats. If that's OK with you, I can try to integrate it into eagerpy and propose a PR in the coming days.
Below is a first POC that I wrote for the wrapper function eager_function (it works).
At this stage, it is not obvious to me how to integrate it into the code base.
I would appreciate some initial feedback on this code, to know whether we can go further in this direction.
```python
import numbers
from collections import defaultdict
from functools import wraps
from typing import Any, Tuple

import numpy as np  # needed for the np.datetime64 check below

import eagerpy as ep


def _tuple_as(template, data):
    """Create a list/tuple/namedtuple of the same type as template, filled with data."""
    data = list(data)
    if hasattr(template, "_fields"):
        # namedtuple case: fields are passed positionally
        return type(template)(*data)
    # list, tuple case
    return type(template)(data)


def _dict_as(template, data):
    """Create a dictionary-like data structure from a template object.

    Parameters
    ----------
    template
        object used as template
    data
        data used to fill the created object.
    """
    if isinstance(template, defaultdict):
        return type(template)(template.default_factory, data)
    return type(template)(data)


def as_eager_tensors(data: Any) -> Any:
    return as_eager_tensors_(data)[0]


def as_eager_tensors_(data: Any) -> Tuple[Any, bool]:
    """Convert to eagerpy tensors.

    Parameters
    ----------
    data : (tuple, list, dict, namedtuple, defaultdict)
        data structure to convert

    Returns
    -------
    converted : Any
        the same data structure, with raw tensors converted to eagerpy tensors.
    unwrap : bool
        if True, at least one raw tensor has been converted
        to an eagerpy tensor.
    """
    if isinstance(data, dict):
        # dict, defaultdict
        if not data:
            return data, False
        keys, res_values, unwrap_values = zip(
            *[(dim,) + as_eager_tensors_(var) for dim, var in data.items()]
        )
        unwrap = True in unwrap_values
        return _dict_as(data, dict(zip(keys, res_values))), unwrap
    elif isinstance(data, (list, tuple)):
        if not data:
            return data, False
        res_values, unwrap_values = zip(*[as_eager_tensors_(var) for var in data])
        unwrap = True in unwrap_values
        # _tuple_as also handles single-field namedtuples correctly
        return _tuple_as(data, res_values), unwrap
    elif isinstance(data, ep.Tensor):
        return data, False
    elif isinstance(data, np.datetime64):
        # datetime not managed by ep.tensors
        return data, False
    elif isinstance(data, numbers.Number):
        return data, False
    # ep.astensor returns non-tensor objects unchanged, so only report
    # unwrap=True when a raw tensor was actually converted.
    converted = ep.astensor(data)
    return converted, isinstance(converted, ep.Tensor)


def as_raw_tensors(data):
    """Convert from eagerpy tensors to raw tensors.

    Parameters
    ----------
    data
        data to convert
    """
    if isinstance(data, dict):
        return _dict_as(data, {dim: as_raw_tensors(var) for dim, var in data.items()})
    elif isinstance(data, (list, tuple)):
        return _tuple_as(data, (as_raw_tensors(var) for var in data))
    elif isinstance(data, ep.Tensor):
        return data.raw
    else:
        return data


def restore_tensor_type(data: Any, unwrap: bool) -> Any:
    if unwrap:
        return as_raw_tensors(data)
    else:
        return data


def eager_function(func):
    @wraps(func)
    def eager_func(*args, **kwargs):
        self = None
        # Heuristic: a dotted __qualname__ is taken to mean "method", so the
        # first positional argument (self) is excluded from the conversion.
        # Note that nested functions ("outer.<locals>.inner") also match.
        if len(func.__qualname__.split(".")) > 1:
            args = list(args)
            self = args.pop(0)
        args, args_unwrap = as_eager_tensors_(args)
        kwargs, kwargs_unwrap = as_eager_tensors_(kwargs)
        unwrap = args_unwrap or kwargs_unwrap
        if self is not None:
            args = [self] + list(args)
        result = func(*args, **kwargs)
        return restore_tensor_type(result, unwrap)

    return eager_func
```
Another possibility would be to use the pytrees implemented in JAX. This would make it possible to handle more data structures and to rely on the existing astensors_ implementation, applied to flattened versions of the inputs and outputs. However, this would create a hard dependency on JAX in eagerpy, whereas JAX is currently optional.
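The pytree idea is to flatten arbitrarily nested inputs into a flat list of leaves, convert the leaves, and rebuild the original structure. In JAX this is provided by jax.tree_util.tree_flatten and tree_unflatten; the sketch below is a toy pure-Python analogue (tree_flatten and tree_map here are hypothetical helpers, not JAX's implementation) so that it runs without JAX. It handles dict, list, tuple and namedtuple nodes only, and rebuilds dicts as plain dicts.

```python
from typing import Any, Callable, List, Tuple


def tree_flatten(tree: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Toy analogue of jax.tree_util.tree_flatten.

    Returns the leaves and a function that rebuilds the same structure
    from a new list of leaves. Non-container objects are leaves.
    """
    if isinstance(tree, dict):
        keys = list(tree)
        children = [tree[k] for k in keys]

        def rebuild_node(vals: List[Any]) -> Any:
            return dict(zip(keys, vals))

    elif isinstance(tree, (list, tuple)):
        children = list(tree)

        def rebuild_node(vals: List[Any]) -> Any:
            if hasattr(tree, "_fields"):  # namedtuple: positional fields
                return type(tree)(*vals)
            return type(tree)(vals)

    else:
        return [tree], lambda leaves: leaves[0]

    leaves: List[Any] = []
    parts = []  # (rebuild function, number of leaves) for each child
    for child in children:
        child_leaves, child_rebuild = tree_flatten(child)
        leaves.extend(child_leaves)
        parts.append((child_rebuild, len(child_leaves)))

    def unflatten(new_leaves: List[Any]) -> Any:
        vals, start = [], 0
        for child_rebuild, n in parts:
            vals.append(child_rebuild(new_leaves[start:start + n]))
            start += n
        return rebuild_node(vals)

    return leaves, unflatten


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf and rebuild the original structure."""
    leaves, unflatten = tree_flatten(tree)
    return unflatten([fn(leaf) for leaf in leaves])


tree = {"a": [1, (2, 3)], "b": 4}
print(tree_flatten(tree)[0])             # [1, 2, 3, 4]
print(tree_map(lambda v: v * 10, tree))  # {'a': [10, (20, 30)], 'b': 40}
```

In an eagerpy integration, fn would presumably be the per-leaf conversion (for example ep.astensor), or the whole flat list of leaves could be converted at once before calling unflatten.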
I propose an implementation based on pytrees in https://github.com/jonasrauber/eagerpy/pull/41.
This approach implies a few changes, such as no longer registering JAXTensor as a pytree data structure and instead using JAX's pytree utilities for more general data structure manipulation in eagerpy. The newly introduced data structure conversion functions also make it possible to simplify the method JAXTensor._value_and_grad_fn somewhat (the method for which the initial registration of JAXTensor was tailored).
On reflection, I think it is not a good idea to stop registering JAXTensor as a pytree: that would break compatibility with JAX functionality. I will restore the registration in an update of the PR.
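Why the registration matters can be illustrated with a toy registry. In JAX, any object not registered as a pytree node is treated as an opaque leaf, and transformations expect leaves to be arrays, which is presumably why an unregistered JAXTensor would break compatibility. The real call is jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func), where flatten_func returns (children, aux_data) and unflatten_func takes (aux_data, children). The names register_node, leaves_of, tree_map and Wrapped below are hypothetical, pure-Python stand-ins so the snippet runs without JAX.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: node type -> (flatten, unflatten). Same shape as
# jax.tree_util.register_pytree_node: flatten(node) -> (children, aux_data),
# unflatten(aux_data, children) -> node.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register_node(nodetype: type, flatten: Callable, unflatten: Callable) -> None:
    _REGISTRY[nodetype] = (flatten, unflatten)


def leaves_of(tree: Any) -> List[Any]:
    """Collect leaves; unregistered types are treated as opaque leaves."""
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        return [tree]
    children, _aux = entry[0](tree)
    return [leaf for child in children for leaf in leaves_of(child)]


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to each leaf and rebuild registered nodes around the results."""
    entry = _REGISTRY.get(type(tree))
    if entry is None:
        return fn(tree)
    flatten, unflatten = entry
    children, aux = flatten(tree)
    return unflatten(aux, [tree_map(fn, child) for child in children])


register_node(list, lambda xs: (list(xs), None), lambda _aux, cs: list(cs))
register_node(tuple, lambda xs: (list(xs), None), lambda _aux, cs: tuple(cs))


class Wrapped:
    """Stand-in for a wrapper such as JAXTensor, holding a raw value."""

    def __init__(self, raw: Any) -> None:
        self.raw = raw


tree = [Wrapped(1.0), (Wrapped(2.0), 3.0)]
before = leaves_of(tree)  # the Wrapped objects themselves are the leaves

# Register the wrapper: its raw value becomes its only child.
register_node(Wrapped, lambda w: ([w.raw], None), lambda _aux, cs: Wrapped(cs[0]))
after = leaves_of(tree)   # [1.0, 2.0, 3.0]
doubled = tree_map(lambda v: v * 2, tree)  # Wrapped nodes are rebuilt around results
```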
To simplify writing universal functions, it would be great to have a decorator that hides the technical part of the code (the conversion of the inputs and outputs of the wrapped function/method). For example, the code:
would become:
In addition, we could make it so that if the input tensors are already eagerpy tensors, the output tensors are not converted back to the raw format.
I wrote a prototype of such a decorator function. It will not work with every type of argument, so using it would require the wrapped function to have a rather "simple" signature (args and kwargs made of tensors, or of nested containers with tensors at the leaves: dict, list, tuple or namedtuple-like containers).
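A self-contained sketch of the decorator behaviour described above, under simplifying assumptions: Eager is a hypothetical stand-in for ep.Tensor, Python floats play the role of raw tensors, and dicts are rebuilt as plain dicts. Raw leaves of nested dict/list/tuple/namedtuple arguments are wrapped before the call, so the function body only sees Eager objects, and the result is unwrapped only if at least one input leaf was raw. The names _convert, _rebuild, Stats and squares are illustrative only.

```python
from collections import namedtuple
from functools import wraps
from typing import Any, Tuple


class Eager:
    """Hypothetical stand-in for ep.Tensor; floats play the raw tensors."""

    def __init__(self, raw: float) -> None:
        self.raw = raw

    def square(self) -> "Eager":
        return Eager(self.raw * self.raw)


def _rebuild(template: Any, values: list) -> Any:
    """Rebuild a list/tuple/namedtuple of the same type as template."""
    if hasattr(template, "_fields"):  # namedtuple takes positional fields
        return type(template)(*values)
    return type(template)(values)


def _convert(data: Any, to_eager: bool) -> Tuple[Any, bool]:
    """Recursively wrap (or unwrap) leaves; flag whether any leaf changed."""
    if isinstance(data, dict):
        pairs = {k: _convert(v, to_eager) for k, v in data.items()}
        changed = any(c for _, c in pairs.values())
        return {k: v for k, (v, _) in pairs.items()}, changed
    if isinstance(data, (list, tuple)):
        pairs = [_convert(v, to_eager) for v in data]
        return _rebuild(data, [v for v, _ in pairs]), any(c for _, c in pairs)
    if to_eager and isinstance(data, float):
        return Eager(data), True
    if not to_eager and isinstance(data, Eager):
        return data.raw, True
    return data, False


def eager_function(func):
    """Wrap raw inputs as Eager; unwrap outputs only if an input was raw."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        args, raw_args = _convert(args, to_eager=True)
        kwargs, raw_kwargs = _convert(kwargs, to_eager=True)
        result = func(*args, **kwargs)
        if raw_args or raw_kwargs:
            result, _ = _convert(result, to_eager=False)
        return result

    return wrapper


Stats = namedtuple("Stats", ["a", "b"])


@eager_function
def squares(pair: Stats) -> dict:
    # No conversion code here: the decorator handles inputs and outputs.
    return {"a": pair.a.square(), "b": pair.b.square()}


print(squares(Stats(2.0, 3.0)))                          # {'a': 4.0, 'b': 9.0}
print(squares(Stats(Eager(2.0), Eager(3.0)))["a"].raw)  # 4.0 (output stays Eager)
```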
Would you consider adding this feature to eagerpy?