enthought / scimath

UnitArray __rmul__ and __rdiv__ drop units when isinstance(other, numpy.generic) #6

Open rupertnash opened 12 years ago

rupertnash commented 12 years ago

I have discovered that scimath.units drops the units from the result when the left-hand operand of * or / is a numpy scalar. Simple test cases:

>>> import numpy
>>> from scimath.units.api import UnitArray
>>> from scimath.units import SI
>>> a = UnitArray(1., units=SI.meter)
>>> b = numpy.array(1.)
>>> a*b # This will work
UnitArray(1.0, units='1.0*m')
>>> b*a # This also works
UnitArray(1.0, units='1.0*m')
>>> c = numpy.sqrt(1.)
>>> type(c)
<type 'numpy.float64'>
>>> a*c # this, i.e. __mul__ works
UnitArray(1.0, units='1.0*m')
>>> c*a # __rmul__ fails!
UnitArray(1.0, units='None')
>>> c / a # __rdiv__ shows the same bug
UnitArray(1.0, units='None')

It is not immediately clear to me how to fix this.

mdickinson commented 12 years ago

So the issue is really in numpy: if c is of type numpy.float64, then c.__mul__ incorrectly accepts a UnitArray instance and returns a result without units, rather than delegating responsibility for the multiplication to UnitArray.__rmul__ as we'd like it to.
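The dispatch rule at play here can be sketched in pure Python. The classes below (`Greedy`, `Polite`, `WithUnits`) are hypothetical stand-ins invented for illustration; they are not numpy or scimath types. The sketch only shows that the right-hand operand's `__rmul__` is consulted when the left-hand operand's `__mul__` returns `NotImplemented`, and is skipped otherwise.

```python
class Greedy(object):
    """Hypothetical left operand whose __mul__ accepts any right operand."""

    def __mul__(self, other):
        return "Greedy.__mul__"


class Polite(object):
    """Hypothetical left operand that declines types it does not know."""

    def __mul__(self, other):
        return NotImplemented


class WithUnits(object):
    """Hypothetical right operand that wants to handle the product itself."""

    def __rmul__(self, other):
        return "WithUnits.__rmul__"


# The reflected method never runs, because Greedy.__mul__ already answered.
greedy_result = Greedy() * WithUnits()

# Here Python falls back to the reflected method on the right operand.
polite_result = Polite() * WithUnits()
```

In this picture the numpy scalar plays the role of `Greedy`, which is why `UnitArray.__rmul__` never gets a chance to run.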

The only way I see to fix this in scimath.units would be to not have UnitArray inherit from ndarray any more, but that's a bit of a drastic move that would almost certainly break a lot of existing code. [From a purity point of view, I think that it's actually the right move: a UnitArray is a product of an array with a unit, and for me it's a stretch to regard that as an ndarray in its own right. I remember seeing similar difficulties from people trying to create a 'Money' class that inherited from the standard library 'Decimal' type.]
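For comparison, here is a rough sketch of what the composition approach could look like. `Quantity` is a hypothetical class, not part of scimath, and the units are a plain string label rather than real unit objects. It also relies on setting `__array_ufunc__ = None`, a NumPy feature from version 1.13 onwards, so it postdates this discussion and is shown only as one possible way to make numpy operands defer.

```python
import numpy as np


class Quantity(object):
    """Hypothetical value-plus-unit pair that does not subclass ndarray."""

    # Opt out of NumPy's ufunc machinery (NumPy >= 1.13), so that
    # `numpy_scalar * q` returns NotImplemented and Python falls back
    # to Quantity.__rmul__.
    __array_ufunc__ = None

    def __init__(self, value, units):
        self.value = np.asarray(value, dtype=float)
        self.units = units  # plain string label, for illustration only

    def __mul__(self, other):
        if isinstance(other, Quantity):
            return NotImplemented  # unit algebra is out of scope here
        return Quantity(self.value * other, self.units)

    # Multiplication by a dimensionless number commutes.
    __rmul__ = __mul__


q = Quantity([1.0, 2.0], "m")
r = np.float64(3.0) * q  # numpy scalar on the left
```

The cost is the one already mentioned: a `Quantity` like this is no longer an ndarray, so existing code that passes it to array-consuming functions would need to change.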

rupertnash commented 12 years ago

I don't think this is the whole story, though. I ran c*a under pdb and saw that the __array_wrap__ method is called afterwards. My understanding of the numpy API is that this method is called on whichever of the operands has the highest __array_priority__ attribute. So this method might be a suitable place to sort out the units?
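A minimal sketch of that idea follows. `TaggedArray` is a hypothetical stand-in for `UnitArray`, with a string label in place of real unit objects, and it only handles a dimensionless left operand; it is not a patch for scimath. It assumes that ufuncs pass a `context` tuple of `(ufunc, args, output_index)` to `__array_wrap__`, and it accepts the extra `return_scalar` argument that NumPy 2.x passes.

```python
import numpy as np


class TaggedArray(np.ndarray):
    """Hypothetical ndarray subclass carrying a `units` string label."""

    # Higher than ndarray's default of 0.0, so ufuncs pick this class's
    # __array_wrap__ when mixing with plain arrays and numpy scalars.
    __array_priority__ = 10.0

    def __new__(cls, data, units=None):
        obj = np.asarray(data, dtype=float).view(cls)
        obj.units = units
        return obj

    def __array_finalize__(self, obj):
        # Called for views and slices; inherit the label if there is one.
        self.units = getattr(obj, "units", None)

    def __array_wrap__(self, out_arr, context=None, return_scalar=False):
        # `return_scalar` is only passed by NumPy 2.x and is ignored here.
        result = np.asarray(out_arr).view(type(self))
        result.units = self.units
        if context is not None:
            ufunc, inputs = context[0], context[1]
            left_units = getattr(inputs[0], "units", None)
            if ufunc in (np.divide, np.true_divide) and left_units is None:
                # dimensionless / quantity: invert the label
                result.units = "1/(%s)" % self.units
        return result


a = TaggedArray([1.0, 2.0], units="m")
c = np.float64(3.0)
product = c * a   # numpy scalar on the left
quotient = c / a  # numpy scalar on the left
```

With this sketch, `product.units` is `'m'` and `quotient.units` is `'1/(m)'`. Combining two operands that both carry units would need real unit arithmetic on the entries of `context[1]`.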