google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

add attribute to devicearray? #2464

Open benjaminpope opened 4 years ago

benjaminpope commented 4 years ago

I'm trying to add a little wrapper to a DeviceArray to handle astropy.units, which are widely used in astronomy code. I can't figure out how to make this work with DeviceArray's slotted type (it defines __slots__), which doesn't let me add new attributes!

Here is an example of what I want to do, but for Theano: https://github.com/exoplanet-dev/exoplanet/blob/master/src/exoplanet/units.py

shoyer commented 4 years ago

It looks like your code is modifying a Theano tensor in-place:

import theano.tensor as tt

# UNIT_ATTR_NAME is a string constant from the original module (not shown in this excerpt)
def with_unit(obj, unit):
    """Decorate a Theano tensor with Astropy units
    Args:
        obj: The Theano tensor
        unit (astropy.Unit): The units for this object
    Raises:
        TypeError: If the tensor already has units
    """
    if hasattr(obj, UNIT_ATTR_NAME):
        raise TypeError("{0} already has units".format(repr(obj)))
    obj = tt.as_tensor_variable(obj)
    # Mutates the tensor object in place by attaching a new attribute
    setattr(obj, UNIT_ATTR_NAME, unit)
    return obj

Unfortunately I don't think we can make this sort of code work in JAX. In-place modification of JAX's array objects is not supported because it breaks JAX's transformations, all of which assume they are being applied to pure functions.

We do want to support use cases like arrays with units in JAX, but we'll have to figure out another way to do that.
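One way to respect the "pure functions" constraint is to never mutate the array at all and instead carry the unit next to it in a small container. The sketch below assumes a hypothetical Quantity class (not part of JAX or astropy) registered with jax.tree_util.register_pytree_node_class: the array is a pytree leaf that JAX traces, while the unit is stored as static auxiliary data, so it passes through jax.jit unchanged. Its with_unit method mirrors the Theano helper but returns a new object instead of calling setattr. It assumes the unit is hashable (strings and astropy units should be), since JAX compares this auxiliary data when caching compiled functions.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Quantity:
    """Hypothetical array-with-unit container (not part of JAX or astropy)."""

    def __init__(self, value, unit=None):
        self.value = value  # the JAX array; a pytree leaf that JAX traces
        self.unit = unit    # static metadata, e.g. an astropy unit or a string

    def with_unit(self, unit):
        # Functional analogue of the Theano helper: build a new object
        # instead of attaching an attribute to the existing one.
        if self.unit is not None:
            raise TypeError(f"{self!r} already has units")
        return Quantity(self.value, unit)

    def __repr__(self):
        return f"Quantity({self.value!r}, unit={self.unit!r})"

    def tree_flatten(self):
        # (children traced by JAX, aux_data treated as a static constant)
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        return cls(children[0], unit)


@jax.jit
def double(q):
    # The unit rides along untouched because it lives in aux_data.
    return Quantity(2 * q.value, q.unit)


q = Quantity(jnp.array([10.0, 11.0])).with_unit("m")
out = double(q)
print(out.unit)  # m
print(out.value)
```

This does no unit arithmetic or checking: code like 2 * q.value still has to propagate or convert units by hand (for instance with astropy), and jax.grad would return gradients tagged with the input's unit rather than the physically correct one.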

benjaminpope commented 4 years ago

I would happily settle for a subclass that had an extra method for manipulating these units, even if it didn't do so in-place!

--
Dr Benjamin Pope
NASA Sagan Fellow
Center for Cosmology and Particle Physics // Center for Data Science
New York University
benjaminpope.github.io

mattjj commented 4 years ago

You can subclass DeviceArray to add more slots:

from jax.interpreters.xla import DeviceArray

class MyDeviceArray(DeviceArray):
  __slots__ = ["unit"]

Is that like what you have in mind?
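The __slots__ behaviour mattjj describes can be checked in plain Python. This sketch uses a hypothetical SlottedBase class as a stand-in for DeviceArray, whose constructor is internal (and in recent JAX releases DeviceArray has been replaced by jax.Array). The key point is that the extra slot exists only on objects that really are instances of the subclass.

```python
class SlottedBase:
    # Hypothetical stand-in for DeviceArray: fixed slots and no __dict__.
    __slots__ = ["_value"]

    def __init__(self, value):
        self._value = value


class WithUnit(SlottedBase):
    # Adding a slot in a subclass only affects instances of that subclass.
    __slots__ = ["unit"]


a = WithUnit([10, 11])
a.unit = "m"  # works: "unit" is a slot of WithUnit

b = SlottedBase([10, 11])
try:
    b.unit = "m"  # fails: base-class instances have no "unit" slot
except AttributeError as e:
    print("AttributeError:", e)
```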

benjaminpope commented 4 years ago

Yeah. But I don't quite know how to get this running:

import jax.numpy as np
from jax.interpreters.xla import DeviceArray

class newarray(DeviceArray):
    __slots__ = ["unit"]
    def __new__(cls, array):
        # Returns the input unchanged, so the result is not a newarray
        return array

test = newarray(np.array([10,11]))
print(test.__slots__)
print(type(test))
test.unit = 'test'

['_npy_value', '_device', '_lazy_expr']
<class 'jax.interpreters.xla.DeviceArray'>
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-131-8467a0f54a30> in <module>
     11 print(test.__slots__)
     12 print(type(test))
---> 13 test.unit = 'test'

AttributeError: 'DeviceArray' object has no attribute 'unit'
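The error comes from __new__: when __new__ returns an object that is not an instance of the class being constructed, Python skips __init__ and hands that object back unchanged. So test is the original DeviceArray, which has no unit slot, exactly as type(test) reports. The plain-Python sketch below reproduces this with a hypothetical SlottedBase stand-in and contrasts it with a subclass that builds a genuine instance of itself. Constructing a real DeviceArray subclass instance depended on JAX internals, so this only illustrates the Python mechanics, not a recipe for JAX.

```python
class SlottedBase:
    # Hypothetical stand-in for DeviceArray (slots only, no __dict__).
    __slots__ = ["_value"]

    def __init__(self, value):
        self._value = value


class Broken(SlottedBase):
    __slots__ = ["unit"]

    def __new__(cls, array):
        # Mirrors the snippet above: returns the argument unchanged,
        # so the result is the original object, not a Broken instance.
        return array


class Fixed(SlottedBase):
    __slots__ = ["unit"]

    def __init__(self, array, unit=None):
        # Build a genuine subclass instance from the wrapped data.
        super().__init__(array._value)
        self.unit = unit


base = SlottedBase([10, 11])
broken = Broken(base)
print(type(broken).__name__)  # SlottedBase
fixed = Fixed(base, unit="m")
print(type(fixed).__name__, fixed.unit)  # Fixed m
```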