Unidata / MetPy

MetPy is a collection of tools in Python for reading, visualizing and performing calculations with weather data.
https://unidata.github.io/MetPy/
BSD 3-Clause "New" or "Revised" License

Unexpected DimensionalityError for quantity coming from dask array #2945

Open gerritholl opened 1 year ago

gerritholl commented 1 year ago

What went wrong?

I have an xarray.DataArray wrapping a dask array. The DataArray has attributes including units, so I use the da.metpy.quantify() accessor method to turn it into a pint quantity that metpy can work with. However, despite the result being a quantity, metpy calculations fail with a DimensionalityError.

Operating System

Linux

Version

1.4.0

Python Version

3.11.0

Code to Reproduce

import metpy.calc
import xarray as xr
import dask.array as da

t1 = xr.DataArray(da.array([300]), attrs={"units": "K"}).metpy.quantify().compute().item()
t2 = t1.magnitude * t1.units
print(t1, t2, t1==t2)
print(metpy.calc.saturation_vapor_pressure(t2))  # works
print(metpy.calc.saturation_vapor_pressure(t1))  # fails

Errors, Traceback, and Logs

300 kelvin 300 kelvin True
3534.519666889136 pascal
Traceback (most recent call last):
  File "/data/gholl/checkouts/protocode/mwe/metpy-dask-dimensionerror.py", line 9, in <module>
    print(metpy.calc.saturation_vapor_pressure(t1))  # fails
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gholl/mambaforge/envs/py311/lib/python3.11/site-packages/metpy/xarray.py", line 1329, in wrapper
    result = func(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gholl/mambaforge/envs/py311/lib/python3.11/site-packages/metpy/units.py", line 376, in wrapper
    result = func(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gholl/mambaforge/envs/py311/lib/python3.11/site-packages/metpy/calc/thermo.py", line 1302, in saturation_vapor_pressure
    17.67 * (temperature - 273.15) / (temperature - 29.65)
             ~~~~~~~~~~~~^~~~~~~~
  File "/data/gholl/mambaforge/envs/py311/lib/python3.11/site-packages/pint/facets/plain/quantity.py", line 971, in __sub__
    return self._add_sub(other, operator.sub)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gholl/mambaforge/envs/py311/lib/python3.11/site-packages/pint/facets/plain/quantity.py", line 102, in wrapped
    return f(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/gholl/mambaforge/envs/py311/lib/python3.11/site-packages/pint/facets/plain/quantity.py", line 858, in _add_sub
    raise DimensionalityError(self._units, "dimensionless")
pint.errors.DimensionalityError: Cannot convert from 'kelvin' to 'dimensionless'

gerritholl commented 1 year ago

I found that isinstance(t1, metpy.units.units.Quantity) is False, whereas the same check is True for t2.

This leads to a divergence in behaviour in metpy.units._mutate_arguments:

https://github.com/Unidata/MetPy/blob/feeca2672d99656834684770ad0cf7c166040110/src/metpy/units.py#L204-L217

What I don't know yet is why t1 is not an instance of Quantity, and whether this is due to a problem in metpy, in pint, or somewhere else.
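To make the divergence concrete, here is a toy sketch in plain Python. It is not MetPy or pint code: BaseQuantity, RegistryQuantity, strip_units and excess_over_freezing are all hypothetical stand-ins. It only assumes the general pattern described above, where a wrapper strips units from arguments that pass an isinstance check and passes everything else through untouched.

```python
# Toy illustration only: hypothetical stand-ins, not pint or MetPy code.

class DimensionalityError(TypeError):
    """Raised when a quantity with units is combined with a bare number."""


class BaseQuantity:
    def __init__(self, magnitude, units):
        self.magnitude = magnitude
        self.units = units

    def __sub__(self, other):
        # Subtracting a bare number from a quantity with units is refused.
        if not isinstance(other, BaseQuantity):
            raise DimensionalityError(
                f"Cannot convert from '{self.units}' to 'dimensionless'")
        return type(self)(self.magnitude - other.magnitude, self.units)


class RegistryQuantity(BaseQuantity):
    """Stands in for the registry-specific class the wrapper checks for."""


def strip_units(func):
    """Pass the magnitude of recognised quantities; leave anything else as-is."""
    def wrapper(value):
        if isinstance(value, RegistryQuantity):
            value = value.magnitude
        return func(value)
    return wrapper


@strip_units
def excess_over_freezing(temperature):
    # Written for bare numbers, like a unit-stripped calculation.
    return temperature - 273.15


good = RegistryQuantity(300.0, "kelvin")
bad = BaseQuantity(300.0, "kelvin")  # same value and units, wrong class

print(excess_over_freezing(good))    # units stripped, plain float arithmetic
try:
    excess_over_freezing(bad)        # not recognised, so units are not stripped
except DimensionalityError as err:
    print(err)
```

In this sketch the two objects compare equal in value and units, yet only the one that passes the isinstance check reaches the calculation as a bare number, which mirrors the t1 versus t2 behaviour in the report.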

gerritholl commented 1 year ago

If I instead calculate t3 = xr.DataArray(da.array([300]), attrs={"units": "K"}).metpy.quantify().data[0].compute(), changing the order of operations, then t3 belongs to the correct unit registry and the calculation is successful.

sgdecker commented 1 year ago

I'm seeing the same thing in a new environment with all the latest and greatest from conda-forge. Code:

import xarray as xr
import intake
import metpy.calc as mpcalc

catalog_url = ('https://raw.githubusercontent.com/NCAR/intake-esm-datastore'
               '/master/catalogs/pangeo-cmip6.json')
catalog = intake.open_esm_datastore(catalog_url)

ncar_press = catalog.search(experiment_id='ssp585', table_id='6hrLev',
                            variable_id='ps', institution_id='NCAR')
ds_press = ncar_press.to_dataset_dict()

sfcpress = ds_press['ScenarioMIP.NCAR.CESM2.ssp585.6hrLev.gn']
sfcpress = sfcpress.ps.squeeze()
press = sfcpress.sel(time='2015-01-01 00:00:00').metpy.quantify()

ncar_temp = catalog.search(experiment_id='ssp585', table_id='6hrLev',
                           variable_id='ta', institution_id='NCAR')
ds_temp = ncar_temp.to_dataset_dict()

temp = ds_temp['ScenarioMIP.NCAR.CESM2.ssp585.6hrLev.gn']
temp = temp.ta.squeeze()
temp = temp.sel(time='2015-01-01 00:00:00').metpy.quantify()

# Correct Pressure
correct = temp['lev']*press

ncar_hum = catalog.search(experiment_id='ssp585', table_id='6hrLev',
                          variable_id='hus', institution_id='NCAR')
ds_hum = ncar_hum.to_dataset_dict()

hum = ds_hum['ScenarioMIP.NCAR.CESM2.ssp585.6hrLev.gn']
hum = hum.hus.squeeze()
hum = hum.sel(time='2015-01-01 00:00:00').metpy.quantify()

dewpt = mpcalc.dewpoint_from_specific_humidity(correct, temp, hum).rename('dewpt')

ds = xr.merge((temp, dewpt))
test = float(f"{ds.lev[0].values:.20f}")

cape_array = xr.zeros_like(ds.ta)
cape_array = cape_array.sel(lev=test, drop=True).rename('cape')

Tarray = ds['ta'].compute()
Tdarray = ds['dewpt'].compute()
correct1 = correct.compute()

for j in range(len(ds.lat)):
    for i in range(len(ds.lon)):
        p = correct1.isel(lon=i, lat=j, time=0)
        T = Tarray.isel(lon=i, lat=j, time=0)
        Td = Tdarray.isel(lon=i, lat=j, time=0)
        print(p)
        print(T)
        print(Td)
        newp, newT, newTd, prof = mpcalc.parcel_profile_with_lcl(p, T, Td)

Output:

--> The keys in the returned dictionary of datasets are constructed as follows:
    'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'
 |███████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [1/1 00:11<00:00]
--> The keys in the returned dictionary of datasets are constructed as follows:
    'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'
 |███████████████████████████████████████████████████████████████████████████████████████████████| 100.00% [1/1 00:10<00:00]
--> The keys in the returned dictionary of datasets are constructed as follows:
    'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'
/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/dask/core.py:119: RuntimeWarning: invalid value encountered in log
  return func(*(_execute_task(a, cache) for a in args))
<xarray.DataArray (lev: 32)>
<Quantity([67628.22052148 66522.33589235 65238.66960699 63788.26552446
 62183.41409796 60437.49194197 58564.7572183  55929.52483464
 52014.8804178  47108.1051129  41547.52385727 35749.77790199
 30387.88870481 25830.19705656 21956.085445   18663.02691606
 15863.87435834 13484.54943377 11462.08546616  9742.95803943
  8281.6715776   7039.5552028   5983.73588945  5025.05205039
  4191.72462132  3521.06778979  2943.02404685  2447.64551455
  1676.96380098   978.19508298   517.47618133   248.24904369], 'pascal')>
Coordinates:
  * lev             (lev) float64 0.9926 0.9763 0.9575 ... 0.007595 0.003643
    member_id       <U9 'r11i1p1f1'
    dcpp_init_year  float64 nan
    lat             float64 -90.0
    lon             float64 0.0
    time            object 2015-01-01 00:00:00
<xarray.DataArray 'ta' (lev: 32)>
<Quantity([247.763   247.33391 247.54343 247.56554 246.91426 245.79741 244.51468
 243.19357 242.25963 240.0033  236.11937 230.23816 224.3636  219.31567
 216.107   214.95854 214.7533  216.77162 221.03688 224.03021 225.70203
 227.39996 229.06297 230.52847 231.968   233.56319 235.15633 237.39421
 239.93338 244.78337 254.57843 273.04788], 'kelvin')>
Coordinates:
    lat             float64 -90.0
  * lev             (lev) float64 0.9926 0.9763 0.9575 ... 0.007595 0.003643
    lon             float64 0.0
    time            object 2015-01-01 00:00:00
    member_id       <U9 'r11i1p1f1'
    dcpp_init_year  float64 nan
Attributes: (12/18)
    cell_measures:  area: areacella
    cell_methods:   area: mean time: point
    comment:        T
    description:    Air Temperature
    frequency:      6hrPt
    id:             ta
    ...             ...
    time:           time1
    time_label:     time-point
    time_title:     Instantaneous value (i.e. synoptic or time-step value)
    title:          Air Temperature
    type:           real
    variable_id:    ta
<xarray.DataArray 'dewpt' (lev: 32)>
<Quantity([ -27.27566094  -26.89692919  -26.97390008  -29.1442997   -30.86587048
  -31.55465684  -32.14347173  -33.56125249  -39.71800697  -42.99317581
  -45.52171239  -47.50503843  -55.95619985  -61.78036008  -66.41280927
  -69.37795601  -71.57499017  -77.93992385  -83.061864    -85.45568483
  -87.47617172  -89.34300552  -89.17087544  -90.02805716  -90.84138346
  -91.42555242  -92.29315775  -93.28479238  -95.23330666  -97.72190628
 -101.82599772           nan], 'degree_Celsius')>
Coordinates:
    lat             float64 -90.0
  * lev             (lev) float64 0.9926 0.9763 0.9575 ... 0.007595 0.003643
    lon             float64 0.0
    time            object 2015-01-01 00:00:00
    member_id       <U9 'r11i1p1f1'
    dcpp_init_year  float64 nan
Traceback (most recent call last):
  File "/chariton/decker/test/students/NCAR.py", line 57, in <module>
    newp, newT, newTd, prof = mpcalc.parcel_profile_with_lcl(p, T, Td)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/xarray.py", line 1329, in wrapper
    result = func(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/units.py", line 320, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/calc/thermo.py", line 1068, in parcel_profile_with_lcl
    p_l, p_lcl, p_u, t_l, t_lcl, t_u = _parcel_profile_helper(pressure, temperature[0],
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/calc/thermo.py", line 1172, in _parcel_profile_helper
    press_lcl, temp_lcl = lcl(pressure[0], temperature, dewpoint)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/xarray.py", line 1329, in wrapper
    result = func(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/units.py", line 376, in wrapper
    result = func(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/calc/thermo.py", line 464, in lcl
    w = mixing_ratio._nounit(saturation_vapor_pressure._nounit(dewpoint), pressure)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/metpy/calc/thermo.py", line 1302, in saturation_vapor_pressure
    17.67 * (temperature - 273.15) / (temperature - 29.65)
             ~~~~~~~~~~~~^~~~~~~~
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/pint/facets/plain/quantity.py", line 971, in __sub__
    return self._add_sub(other, operator.sub)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/pint/facets/plain/quantity.py", line 102, in wrapped
    return f(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/decker/local/miniconda3/envs/meghan/lib/python3.11/site-packages/pint/facets/plain/quantity.py", line 858, in _add_sub
    raise DimensionalityError(self._units, "dimensionless")
pint.errors.DimensionalityError: Cannot convert from 'degree_Celsius' to 'dimensionless'

I see the comment about changing the order of operations as a possible workaround, but am not sure how to apply that to my case.

For the record, I'm running

I tried the same code in an older environment but got a different error.

sgdecker commented 1 year ago

I found the workaround for my code: Do .compute() before .metpy.quantify() rather than later.

dopplershift commented 1 year ago

I think the order of operations points the way to figure out who's at fault, but fundamentally, the problem is somewhere in the Dask + Xarray + Pint integration (call me "shocked"). You can see this without any calculation:

isinstance(t1, t1._REGISTRY.Quantity)

results in False. Somebody is doing something bad.

dopplershift commented 1 year ago

Turns out the problem was in Pint. I opened hgrecco/pint#1722 to fix it, so things should work better once that's merged and released. I'll note that this was needed for me to make the example from @sgdecker work, but not @gerritholl.

Another note: y'all are using way too much .metpy.quantify(). MetPy will call that for you when you pass in xarray data to our calculations. So @gerritholl, I can make your example work simply (and it works today) with just the following, which I'd call the ideal and expected way:

t1 = xr.DataArray(da.array([300]), attrs={"units": "K"})
metpy.calc.saturation_vapor_pressure(t1).compute()

@sgdecker The only call to .metpy.quantify() you need is:

press = sfcpress.sel(time='2015-01-01 00:00:00').metpy.quantify()

That's because you manually multiply the lev and press DataArray instances, and by default xarray will drop metadata when you do that; calling .metpy.quantify() converts from having units information as metadata to storing it in a Quantity() instance, which then keeps the information across operations. The other manual calls to .metpy.quantify() are superfluous.

sgdecker commented 1 year ago

@dopplershift Glad to hear you have a fix, and thanks for the .metpy.quantify() tips. I knew the one .metpy.quantify() was needed for the sigma-to-pressure computation. The others I think we added at some point in an attempt to avoid this or some other DimensionalityError along the way, but obviously that wasn't the real problem.

gerritholl commented 1 year ago

Good to hear I can avoid using .metpy.quantify() entirely. The reason I used it in the first place is because I have dask arrays, but not all metpy functions are dask-aware; for example, parcel_profile is very slow with dask arrays as it calls .compute() 18 times. But if I compute the profiles before passing them to parcel_profile, then parcel_profile fails with the aforementioned DimensionalityError. Trying to shorten that to an MCVE led to the post in the original question, but in my real use case I get the problem even if not calling .metpy.quantify(), at least not directly/explicitly.

gerritholl commented 1 year ago

My workaround is a little harder as I'm trying to replace

parcel_profile(p[::-1], t[-1], d[-1]) (works, but computes the dask arrays 18 (!) times), with

dask.delayed(parcel_profile)(p[::-1], t[-1], d[-1]).compute()

which fails with the aforementioned DimensionalityError. A workaround seems hard without pre-computing (which I don't want to do), but it works with 1 compute if I use the branch behind https://github.com/hgrecco/pint/pull/1722 :+1:

dopplershift commented 1 year ago

Thanks for letting us know @gerritholl. The fix has been merged upstream in pint, so should be fixed with the next release.

Does that fully solve all of the issues identified in this issue?