jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License
693 stars, 39 forks

add type conversions [feature request] #30

Open PeterMitrano opened 3 years ago

PeterMitrano commented 3 years ago

Converting between things like float32/64/int32/64 would be a really nice addition!

jonasrauber commented 3 years ago

We do have support for astype(dtype), which allows you to do type conversions. But you are right, it's not clear if/how this works with specific dtypes. So far, we only use astype together with a tensor's dtype attribute; the two are always compatible. Maybe that already solves your use case.

PeterMitrano commented 3 years ago

Sure, astype exists, but it's not useful as is, IMHO. Currently I have to write my_int_tensor.astype(torch.float32), which defeats the whole point of being framework-agnostic.

Here's how I think it should behave (not sure if this makes sense):

```python
def foo(my_int_tensor: eagerpy.Tensor):
    # proposed API: a framework-agnostic dtype constant
    return my_int_tensor.astype(eagerpy.float32)
```

Here's a simple wrapper that does what I want, though it is not a good implementation:

```python
import eagerpy
import torch
import tensorflow as tf


def my_astype(x: eagerpy.Tensor, dtype: str) -> eagerpy.Tensor:
    # Map a framework-agnostic dtype name to the framework-specific dtype
    if dtype == 'float32':
        if isinstance(x.raw, torch.Tensor):
            specific_dtype = torch.float32
        elif isinstance(x.raw, tf.Tensor):
            specific_dtype = tf.float32
        else:
            # more frameworks here...
            raise NotImplementedError(f"unsupported tensor type: {type(x.raw)}")
        return x.astype(specific_dtype)
    else:
        # more dtypes here...
        raise NotImplementedError(f"unsupported dtype: {dtype}")


x = torch.tensor([1, 2, 3])
ex = eagerpy.astensor(x)

print(ex.dtype)
print(my_astype(ex, 'float32').dtype)

x = tf.constant([1, 2, 3])
ex = eagerpy.astensor(x)

print(ex.dtype)
print(my_astype(ex, 'float32').dtype)
```

jonasrauber commented 3 years ago

Yes, I fully agree. For float32 specifically, we actually also have a tensor.float32() method (but not yet for the other dtypes).

PeterMitrano commented 3 years ago

Oh great, that's actually all I needed at the moment. My apologies for not checking the docs closely enough.

On Wed, Apr 7, 2021 at 11:50 AM Jonas Rauber @.***> wrote:

Yes, I fully agree. For float32 specifically, we actually also have a tensor.float32() method (but not yet for the other dtypes).