Open benjaminpope opened 4 years ago
It looks like your code is modifying a Theano tensor in-place:
def with_unit(obj, unit):
    """Decorate a Theano tensor with Astropy units

    Args:
        obj: The Theano tensor
        unit (astropy.Unit): The units for this object

    Raises:
        TypeError: If the tensor already has units

    """
    if hasattr(obj, UNIT_ATTR_NAME):
        raise TypeError("{0} already has units".format(repr(obj)))
    obj = tt.as_tensor_variable(obj)
    setattr(obj, UNIT_ATTR_NAME, unit)
    return obj
Unfortunately I don't think we can make this sort of code work in JAX. In-place modification of JAX's array objects is not supported because it breaks JAX's transformations, all of which assume they are being applied to pure functions.
We do want to support use cases like arrays with units in JAX, but we'll have to figure out another way to do that.
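One non-mutating alternative is to return a new object that pairs the value with its unit, instead of calling `setattr` on the array. The following is only a minimal sketch under stated assumptions: `WithUnit` and `to_unit` are hypothetical names, the unit is a plain string standing in for an astropy unit, and the value is a plain Python list standing in for an array. Whether JAX transformations accept such an object depends on registering it with JAX's pytree machinery, which this sketch does not attempt.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class WithUnit:
    """Hypothetical immutable pairing of a value and its unit."""
    value: Any
    unit: str  # assumption: a plain string stands in for an astropy unit


def with_unit(obj, unit):
    """Return a new WithUnit rather than setting an attribute on obj."""
    if isinstance(obj, WithUnit):
        raise TypeError("{0!r} already has units".format(obj))
    return WithUnit(value=obj, unit=unit)


def to_unit(quantity, new_unit, factor):
    """Return a converted copy; the input quantity is left untouched."""
    scaled = [v * factor for v in quantity.value]
    return WithUnit(value=scaled, unit=new_unit)


lengths_m = with_unit([1.0, 2.5], "m")
lengths_cm = to_unit(lengths_m, "cm", 100.0)
print(lengths_m)
print(lengths_cm)
```

Because the dataclass is frozen, every "change" of unit produces a fresh object, which keeps functions that use it free of side effects.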
I would happily settle for a subclass which had an extra method, which could manipulate these units, even if it didn't do so in-place!
On Thu, Mar 19, 2020 at 5:53 PM Stephan Hoyer notifications@github.com wrote:
-- Dr Benjamin Pope NASA Sagan Fellow Center for Cosmology and Particle Physics // Center for Data Science New York University benjaminpope.github.io
You can subclass DeviceArray to add more slots:
from jax.interpreters.xla import DeviceArray

class MyDeviceArray(DeviceArray):
    __slots__ = ["unit"]
Is that like what you have in mind?
Yeah. But I don't quite know how to get this running:
from jax.interpreters.xla import DeviceArray

class newarray(DeviceArray):
    __slots__ = ["unit"]

    def __new__(self,array):
        return array

test = newarray(np.array([10,11]))
print(test.__slots__)
print(type(test))
test.unit = 'test'
['_npy_value', '_device', '_lazy_expr']
<class 'jax.interpreters.xla.DeviceArray'>
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-131-8467a0f54a30> in <module>
11 print(test.__slots__)
12 print(type(test))
---> 13 test.unit = 'test'
AttributeError: 'DeviceArray' object has no attribute 'unit'
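The output above is consistent with how Python treats `__new__`: when it returns its argument unchanged, no instance of the subclass is ever created, so the result keeps the original type and the original slots. The sketch below reproduces this in plain Python. `SlottedArray`, `ReturnsArgument`, and `BuildsInstance` are hypothetical stand-ins, not JAX types, and the sketch makes no claim about what `DeviceArray`'s own constructor requires.

```python
class SlottedArray:
    """Hypothetical stand-in for a slotted array class (not a JAX type)."""
    __slots__ = ["_value"]

    def __init__(self, value):
        self._value = value


class ReturnsArgument(SlottedArray):
    __slots__ = ["unit"]

    def __new__(cls, array):
        # Returning the argument means no ReturnsArgument instance is built.
        return array


class BuildsInstance(SlottedArray):
    __slots__ = ["unit"]

    def __init__(self, array, unit=None):
        # A real subclass instance is allocated, so the "unit" slot exists.
        super().__init__(array._value)
        self.unit = unit


base = SlottedArray([10, 11])

same = ReturnsArgument(base)
print(type(same).__name__)  # the base class name, not the subclass
try:
    same.unit = "test"
except AttributeError as err:
    print("AttributeError:", err)

wrapped = BuildsInstance(base, unit="test")
print(type(wrapped).__name__, wrapped.unit)
```

In the first case `same` is the very object that was passed in, which is why assigning `unit` fails. In the second case the subclass instance is actually constructed and the extra slot is available.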
I'm trying to add a little wrapper to a DeviceArray to handle astropy.units, which are widely used in astronomy code. I can't figure out how to make this connect with the slotted type of DeviceArray, which doesn't like to be changed!
Here is an example of what I want to do, but for Theano: https://github.com/exoplanet-dev/exoplanet/blob/master/src/exoplanet/units.py