jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License

Python Scalars Support #32

Open nmichlo opened 3 years ago

nmichlo commented 3 years ago

Feature Request

Would it be possible to support Python scalars?

For example:

num_float = ep.astensor(1.0)
num_int = ep.astensor(42)

Why?

This would be helpful for implementing generic functions, for example:

def get_kernel_size(sigma=1.0, trunc=4.0):
    sigma, trunc = ep.astensors(sigma, trunc)
    radius = (sigma * trunc + 0.5).astype(int)
    return (2 * radius + 1).raw

This could then be called with the Python scalar defaults, with tensors, or with a mix of both:

import numpy as np
import torch

t_sigma = torch.abs(torch.randn(10))
n_trunc = np.abs(np.random.randn(10))

get_kernel_size(sigma=1.0, trunc=4.0)
get_kernel_size(sigma=t_sigma, trunc=4.0)
get_kernel_size(sigma=1.0, trunc=n_trunc)

jonasrauber commented 3 years ago

Nice example. I see why this can be useful. It might be possible to add a simple wrapper class for Python scalars that automatically gets unwrapped when it interacts with any tensor.
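
A minimal sketch of what such a wrapper could look like, assuming a hypothetical `Scalar` class (not part of EagerPy). Arithmetic between two wrapped scalars stays in plain Python and stays wrapped; arithmetic with anything else hands the bare number to the other operand's own operator. NumPy arrays stand in for tensors here; setting `__array_ufunc__ = None` makes NumPy defer `array (op) Scalar` to the wrapper's reflected methods instead of building an object array. Supporting `Tensor (op) Scalar` with an EagerPy tensor on the left would still need changes inside the library itself.

```python
import operator
from typing import Any, Callable, Union

import numpy as np

Number = Union[int, float]


class Scalar:
    """Hypothetical wrapper for a Python int or float (not part of EagerPy)."""

    __slots__ = ("value",)

    # Opt out of NumPy's ufunc machinery so that `ndarray (op) Scalar`
    # returns NotImplemented and Python calls the reflected method below.
    __array_ufunc__ = None

    def __init__(self, value: Number) -> None:
        if not isinstance(value, (int, float)):
            raise TypeError(f"expected int or float, got {type(value).__name__}")
        self.value = value

    def __repr__(self) -> str:
        return f"Scalar({self.value!r})"

    def _binary(self, op: Callable[[Any, Any], Any], other: Any, reflected: bool) -> Any:
        if isinstance(other, Scalar):
            # Scalar (op) Scalar: compute in plain Python, keep the result wrapped.
            a, b = (other.value, self.value) if reflected else (self.value, other.value)
            return Scalar(op(a, b))
        # Scalar (op) tensor: unwrap the number and let the tensor type do the work.
        return op(other, self.value) if reflected else op(self.value, other)

    def __add__(self, other: Any) -> Any:
        return self._binary(operator.add, other, reflected=False)

    def __radd__(self, other: Any) -> Any:
        return self._binary(operator.add, other, reflected=True)

    def __sub__(self, other: Any) -> Any:
        return self._binary(operator.sub, other, reflected=False)

    def __rsub__(self, other: Any) -> Any:
        return self._binary(operator.sub, other, reflected=True)

    def __mul__(self, other: Any) -> Any:
        return self._binary(operator.mul, other, reflected=False)

    def __rmul__(self, other: Any) -> Any:
        return self._binary(operator.mul, other, reflected=True)

    def __truediv__(self, other: Any) -> Any:
        return self._binary(operator.truediv, other, reflected=False)

    def __rtruediv__(self, other: Any) -> Any:
        return self._binary(operator.truediv, other, reflected=True)


# NumPy arrays stand in for framework tensors in this usage example.
x = np.array([1.0, 2.0])
y = Scalar(2.0) * x + Scalar(1)  # a plain ndarray; the wrappers are gone
```

The reflected methods matter as much as the forward ones: without them, `tensor + Scalar(1)` would have no way to unwrap the number.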

nmichlo commented 3 years ago

That sounds ideal.

I am not familiar with the internals of the library, but I imagine this could have implications for the various helper methods and functions. I am not sure how these would or could be handled, or whether it would introduce any special cases.

jonasrauber commented 3 years ago

Yes, I think there are quite a few cases to consider and some corner cases might be challenging, but it should be doable.

For now, this workaround should be fine (assuming that at least one of the arguments is always a tensor):

import eagerpy as ep

def get_kernel_size(sigma=1.0, trunc=4.0):
    # Turn whichever argument is a Python scalar into a 0-d tensor created
    # by the other argument's backend (zeros with its dtype, plus the scalar).
    if isinstance(sigma, (int, float)):
        trunc = ep.astensor(trunc)
        sigma = ep.zeros(trunc, ()) + sigma
    elif isinstance(trunc, (int, float)):
        sigma = ep.astensor(sigma)
        trunc = ep.zeros(sigma, ()) + trunc
    else:
        sigma, trunc = ep.astensors(sigma, trunc)
    radius = (sigma * trunc + 0.5).astype(int)
    return (2 * radius + 1).raw

(I didn't test this)

You can introduce helper functions to simplify it further.

In fact, we could integrate all of the above into ep.astensors. It works as long as at least one of the arguments is actually a tensor (i.e., not all of them can be ints and floats).