google-research / weatherbench2

A benchmark for the next generation of data-driven global weather models.
https://weatherbench2.readthedocs.io
Apache License 2.0

Possible bug with xarray dataset selection within beam #173

Closed: jdwillard19 closed this issue 3 months ago

jdwillard19 commented 4 months ago

I am running "evaluate_with_beam" and it does not appear to be able to select variables from the observations dataset.

This is the line causing the error: the dataset[sel_variables].sel(...) call inside _impose_data_selection() in weatherbench2/evaluation.py (see the stack trace below).

If I load the same dataset and execute the line in question outside the Beam context, using the same Python container environment, it works just fine.
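
For concreteness, here is a minimal sketch of that standalone check; the observation path, variable list, and time range are copied from the command-line call further down and stand in for my actual config values:

    import xarray as xr

    # Standalone reproduction of the selection that fails inside Beam.
    obs_path = "/pscratch/sd/j/jwillard/healpix_era5/data/latlon_1deg_combined_wb.zarr"
    variables = [
        "geopotential", "temperature", "u_component_of_wind",
        "v_component_of_wind", "specific_humidity", "2m_temperature",
        "10m_u_component_of_wind", "10m_v_component_of_wind",
        "mean_sea_level_pressure",
    ]

    ds = xr.open_zarr(obs_path)
    # Mirrors _impose_data_selection: pick the variables, then slice the times
    # (the obs dataset is assumed to use a "time" coordinate).
    subset = ds[variables].sel(time=slice("2020-01-01", "2022-12-31"))
    print(subset)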

Also, if I go into _impose_data_selection() and manually add

print("Variables in dataset:", list(dataset.data_vars.keys()))

it outputs the same list of variables I am trying to select:

['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature', 'geopotential', 'mean_sea_level_pressure', 'specific_humidity', 'surface_pressure', 'temperature', 'total_column_water_vapour', 'u_component_of_wind', 'v_component_of_wind']
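
A useful extra check at the same spot is to print the type of the incoming selection, not just its contents; a minimal sketch (sel_variables is the name used inside _impose_data_selection, per the stack trace below):

    # Alongside the print above: objects such as omegaconf's ListConfig print
    # identically to a plain list, so checking the type catches them where
    # repr() does not.
    print("Type of selection:", type(sel_variables))
    print("Variables in dataset:", list(dataset.data_vars.keys()))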

Code snippet:

    def score_deterministic(self):
        import logging

        import xarray as xr

        import weatherbench2
        from weatherbench2 import config as wb2_config
        from weatherbench2 import evaluation as wb2_evaluation
        from weatherbench2.metrics import MSE, ACC, MAE, Bias

        if self.log_to_screen:
            logging.info("Beginning scoring with WB2....")

        # Define the WB2 configs.
        paths = wb2_config.Paths(
            forecast=self.forecast_output_path,
            obs=self.obs_path,
            output_dir=self.inf_dir,
            climatology=self.climatology_path,
        )
        selection = wb2_config.Selection(
            variables=self.score_variables,
            levels=self.score_levels,
            time_slice=slice(self.score_start_date, self.score_end_date),
        )

        data_config = wb2_config.Data(selection=selection, paths=paths)

        climatology = None
        if self.climatology_path:
            climatology = xr.open_zarr(self.climatology_path)

        # Build the metric dictionary from the requested metric names.
        metrics = {}
        if 'mse' in self.score_metrics:
            metrics['mse'] = MSE()
        if 'acc' in self.score_metrics:
            if climatology is not None:
                metrics['acc'] = ACC(climatology=climatology)
            else:
                raise ValueError("Climatology path must be provided if 'acc' metric is specified.")
        if 'mae' in self.score_metrics:
            metrics['mae'] = MAE()
        if 'bias' in self.score_metrics:
            metrics['bias'] = Bias()

        # Build the evaluation regions.
        regions = {}
        if 'global' in self.score_regions:
            regions['global'] = weatherbench2.regions.SliceRegion()
        if 'tropics' in self.score_regions:
            regions['tropics'] = weatherbench2.regions.SliceRegion(lat_slice=slice(-20, 20))
        if 'extra-tropics' in self.score_regions:
            regions['extra-tropics'] = weatherbench2.regions.ExtraTropicalRegion()

        # Create the eval_configs dictionary.
        eval_config = {
            'deterministic': wb2_config.Eval(
                metrics=metrics,
                regions=regions,
                evaluate_persistence=self.evaluate_persistence,
                evaluate_climatology=self.evaluate_climatology,
            )
        }

        if self.use_beam:
            direct_runner_options = [
                f'--direct_num_workers={self.direct_num_workers}',
                '--direct_running_mode=multi_processing',
            ]

            # Pass the DirectRunner options through to Beam via argv.
            argv = []
            argv.extend(direct_runner_options)
            wb2_evaluation.evaluate_with_beam(
                data_config,
                eval_config,
                runner='DirectRunner',
                input_chunks={'init_time': 1, 'lead_time': 1},
                fanout=self.fanout,
                argv=argv,
            )
        else:
            wb2_evaluation.evaluate_in_memory(data_config, eval_config)

Stack trace

    Traceback (most recent call last):
      File "/global/u2/j/jwillard/healpixnat/inference.py", line 22, in <module>
        run()
      File "/usr/local/lib/python3.10/dist-packages/hydra/main.py", line 94, in decorated_main
        _run_hydra(
      File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
        _run_app(
      File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 457, in _run_app
        run_and_report(
      File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 223, in run_and_report
        raise ex
      File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 220, in run_and_report
        return func()
      File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 458, in <lambda>
        lambda: hydra.run(
      File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
        _ = ret.return_value
      File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
        raise self._return_value
      File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
        ret.return_value = task_function(task_cfg)
      File "/global/u2/j/jwillard/healpixnat/inference.py", line 13, in run
        inferencer.launch()
      File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 108, in launch
        self.build_and_run()
      File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 140, in build_and_run
        self.inference()
      File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 339, in inference
        self.score_deterministic()
      File "/global/u2/j/jwillard/healpixnat/utils/inferencer.py", line 258, in score_deterministic
        wb2_evaluation.evaluate_with_beam(
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 824, in evaluate_with_beam
        root
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/transforms/ptransform.py", line 1110, in __ror__
        return self.transform.__ror__(pvalueish, self.label)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/transforms/ptransform.py", line 623, in __ror__
        result = p.apply(self, pvalueish, label)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/pipeline.py", line 679, in apply
        return self.apply(transform, pvalueish)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/pipeline.py", line 732, in apply
        pvalueish_result = self.runner.apply(transform, pvalueish, self._options)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/runners/runner.py", line 203, in apply
        return self.apply_PTransform(transform, input, options)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/apache_beam/runners/runner.py", line 207, in apply_PTransform
        return transform.expand(input)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 773, in expand
        forecast, truth, climatology = open_forecast_and_truth_datasets(
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 376, in open_forecast_and_truth_datasets
        obs_all_times = _impose_data_selection(
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/weatherbench2/evaluation.py", line 170, in _impose_data_selection
        dataset = dataset[sel_variables].sel(
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/xarray/core/dataset.py", line 1484, in __getitem__
        return self._construct_dataarray(key)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/xarray/core/dataset.py", line 1395, in _construct_dataarray
        _, name, variable = _get_virtual_variable(self._variables, name, self.dims)
      File "/global/u2/j/jwillard/.local/perlmutter/dasrepo_pharring_deepspeed_pytorch_24.05/lib/python3.10/site-packages/xarray/core/dataset.py", line 192, in _get_virtual_variable
        raise KeyError(key)
    KeyError: ['geopotential', 'temperature', 'u_component_of_wind', 'v_component_of_wind', 'specific_humidity', '2m_temperature', '10m_u_component_of_wind', '10m_v_component_of_wind', 'mean_sea_level_pressure']

Environment

Output of pip list (columns: Package, Version, Editable project location):


absl-py 2.1.0 aiobotocore 2.13.0 aiohttp 3.9.5 aioitertools 0.11.0 aiosignal 1.3.1 annotated-types 0.6.0 antlr4-python3-runtime 4.9.3 apache-beam 2.57.0 apex 0.1 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 asciitree 0.3.3 astropy 6.1.0 astropy-iers-data 0.2024.6.10.0.30.47 asttokens 2.4.1 astunparse 1.6.3 async-timeout 4.0.3 attrs 23.2.0 audioread 3.0.1 beautifulsoup4 4.12.3 black 24.4.2 bleach 6.1.0 blis 0.7.11 bokeh 3.5.0 botocore 1.34.106 cachetools 5.3.3 cads-api-client 1.0.3 Cartopy 0.23.0 catalogue 2.0.10 cdsapi 0.7.0 certifi 2024.2.2 cffi 1.16.0 cftime 1.6.4 charset-normalizer 3.3.2 click 8.1.7 cloudpathlib 0.16.0 cloudpickle 2.2.1 cmake 3.29.2 comm 0.2.2 confection 0.1.4 contourpy 1.2.1 crcmod 1.7 cuda-python 12.4.0 cudf 24.4.0 cugraph 24.4.0 cugraph-dgl 24.4.0 cugraph-equivariant 24.4.0 cugraph-pyg 24.4.0 cugraph-service-client 24.4.0 cugraph-service-server 24.4.0 cuml 24.4.0 cupy-cuda12x 13.0.0 cycler 0.12.1 cymem 2.0.8 Cython 3.0.10 dask 2024.1.1 dask-cuda 24.4.0 dask-cudf 24.4.0 dask-expr 0.4.0 debugpy 1.8.1 decorator 5.1.1 deepspeed 0.14.3 defusedxml 0.7.1 dgl 2.2.1 dill 0.3.1.1 distributed 2024.1.1 dm-tree 0.1.8 dnspython 2.6.1 docker-pycreds 0.4.0 docopt 0.6.2 earth2-grid 2024.5.2 einops 0.8.0 entrypoints 0.4 exceptiongroup 1.2.1 execnet 2.1.1 executing 2.0.1 expecttest 0.1.3 fastavro 1.9.5 fasteners 0.19 fastjsonschema 2.19.1 fastrlock 0.8.2 filelock 3.14.0 flash-attn 2.4.2 fonttools 4.51.0 frozenlist 1.4.1 fsspec 2024.6.0 gast 0.5.4 gitdb 4.0.11 GitPython 3.1.43 gnureadline 8.2.10 google-auth 2.29.0 google-auth-oauthlib 0.4.6 grpcio 1.64.1 h5py 3.11.0 hdfs 2.7.3 healpy 1.17.1 hjson 3.1.0 httplib2 0.22.0 huggingface-hub 0.23.3 hydra-core 1.3.2 hypothesis 5.35.1 idna 3.7 igraph 0.11.4 imageio 2.34.1 immutabledict 4.2.0 importlib_metadata 7.1.0 iniconfig 2.0.0 intel-openmp 2021.4.0 ipykernel 6.29.4 ipython 8.21.0 ipython-genutils 0.2.0 jax 0.4.30 jaxlib 0.4.30 jedi 0.19.1 Jinja2 3.1.3 jmespath 1.0.1 joblib 1.4.0 Js2Py 0.74 json5 0.9.25 jsonpickle 3.2.2 jsonschema 4.22.0 jsonschema-specifications 2023.12.1 jupyter_client 8.6.1 jupyter_core 5.7.2 jupyter-tensorboard 0.2.0 jupyterlab 2.3.2 jupyterlab_pygments 0.3.0 jupyterlab-server 1.2.0 jupytext 1.16.1 kiwisolver 1.4.5 kvikio 24.4.0 langcodes 3.4.0 language_data 1.2.0 lark 1.1.9 lazy_loader 0.4 librosa 0.10.1 lightning-thunder 0.2.0.dev0 lightning-utilities 0.11.2 littleutils 0.2.2 llvmlite 0.42.0 locket 1.0.0 looseversion 1.3.0 marisa-trie 1.1.0 Markdown 3.6 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib 3.8.4 matplotlib-inline 0.1.7 mdit-py-plugins 0.4.0 mdurl 0.1.2 mistune 3.0.2 mkl 2021.1.1 mkl-devel 2021.1.1 mkl-include 2021.1.1 ml-dtypes 0.4.0 mock 5.1.0 mpi4py 3.1.6 mpmath 1.3.0 msgpack 1.0.8 multidict 6.0.5 multiurl 0.3.1 murmurhash 1.0.10 mypy-extensions 1.0.0 natten 0.17.1 /opt/NATTEN/src nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 nest-asyncio 1.6.0 netCDF4 1.6.5 networkx 3.3 ninja 1.11.1.1 notebook 6.4.10 numba 0.59.1 numcodecs 0.11.0 numpy 1.24.4 nvfuser 0.2.0a0+0ff5802 nvidia-cudnn-frontend 1.3.0 nvidia-dali-cuda110 1.38.0 nvidia-dali-cuda120 1.37.1 nvidia-ml-py 12.555.43 nvidia-modulus 0.3.0 nvidia-nvimgcodec-cu11 0.2.0.7 nvidia-nvimgcodec-cu12 0.2.0.7 nvidia-pyindex 1.0.9 nvtx 0.2.5 nx-cugraph 24.4.0 oauthlib 3.2.2 objsize 0.7.0 ogb 1.3.6 omegaconf 2.3.0 onnx 1.16.0 opencv 4.7.0 opt-einsum 3.3.0 optree 0.11.0 orjson 3.10.6 outdated 0.2.2 packaging 24.0 pandas 2.0.3 pandocfilters 1.5.1 parso 0.8.4 partd 1.4.1 pathspec 0.12.1 pexpect 4.9.0 Pillow 9.5.0 pip 24.0 platformdirs 4.2.1 pluggy 1.5.0 ply 3.11 
polygraphy 0.49.10 pooch 1.8.1 preshed 3.0.9 prettytable 3.10.0 prometheus_client 0.20.0 prompt-toolkit 3.0.43 properscoring 0.1 proto-plus 1.24.0 protobuf 4.25.3 psutil 5.9.8 ptyprocess 0.7.0 pure-eval 0.2.2 py-cpuinfo 9.0.0 pyarrow 14.0.2 pyarrow-hotfix 0.6 pyasn1 0.6.0 pyasn1_modules 0.4.0 pybind11 2.12.0 pybind11_global 2.12.0 pycocotools 2.0+nv0.8.0 pycparser 2.22 pydantic 2.7.1 pydantic_core 2.18.2 pydot 1.4.2 pyerfa 2.0.1.4 Pygments 2.17.2 pygrib 2.1.5 pyjsparser 2.7.1 pylibcugraph 24.4.0 pylibcugraphops 24.4.0 pylibraft 24.4.0 pylibwholegraph 24.4.0 pymongo 4.8.0 pynvjitlink 0.1.13 pynvml 11.4.1 pyparsing 3.1.2 pyproj 3.6.1 pyshp 2.3.1 pytest 8.1.1 pytest-flakefinder 1.1.0 pytest-rerunfailures 14.0 pytest-shard 0.1.2 pytest-xdist 3.6.1 python-dateutil 2.9.0.post0 python-hostlist 1.23.0 pytorch-quantization 2.1.2 pytorch-triton 3.0.0+989adb9a2 pytz 2024.1 pyvista 0.43.9 PyYAML 6.0.1 pyzmq 26.0.3 raft-dask 24.4.0 rapids-dask-dependency 24.4.0a0 readline 6.2.4.1 rechunker 0.5.2 redis 5.0.7 referencing 0.35.1 regex 2024.4.28 requests 2.31.0 requests-oauthlib 2.0.0 rich 13.7.1 rmm 24.4.0 rpds-py 0.18.0 rsa 4.9 ruamel.yaml 0.18.6 ruamel.yaml.clib 0.2.8 ruff 0.4.8 s3fs 2024.6.0 safetensors 0.4.3 scikit_build_core 0.9.4 scikit-image 0.23.2 scikit-learn 1.4.2 scipy 1.13.0 scooby 0.10.0 Send2Trash 1.8.3 sentry-sdk 2.5.1 setproctitle 1.3.3 setuptools 68.2.2 shapely 2.0.4 six 1.16.0 smart-open 6.4.0 smmap 5.0.1 sortedcontainers 2.4.0 soundfile 0.12.1 soupsieve 2.5 soxr 0.3.7 spacy 3.7.4 spacy-legacy 3.0.12 spacy-loggers 1.0.5 sphinx_glpi_theme 0.6 srsly 2.4.8 stack-data 0.6.3 sympy 1.12 tabulate 0.9.0 tbb 2021.12.0 tblib 3.0.0 tensorboard 2.9.0 tensorboard-data-server 0.6.1 tensorboard-plugin-wit 1.8.1 tensorly 0.8.1 tensorly-torch 0.5.0 tensorrt 10.0.1 terminado 0.18.1 texttable 1.7.0 thinc 8.2.3 threadpoolctl 3.5.0 thriftpy2 0.4.20 tifffile 2024.5.22 timm 1.0.3 tinycss2 1.3.0 toml 0.10.2 tomli 2.0.1 toolz 0.12.1 torch 2.4.0a0+07cecf4168.nv24.5 torch_geometric 2.5.3 torch-harmonics 0.6.5 torch-tensorrt 2.4.0a0 torchdata 0.7.1 torchinfo 1.8.0 torchvision 0.19.0a0 tornado 6.4 tqdm 4.66.4 traitlets 5.9.0 transformer-engine 1.6.0+c81733f treelite 4.1.2 triton 2.3.1 typer 0.9.4 types-dataclasses 0.6.6 typing_extensions 4.11.0 tzdata 2024.1 tzlocal 5.2 ucx-py 0.37.0 urllib3 2.0.7 vtk 9.3.0 wandb 0.17.1 wasabi 1.1.2 wcwidth 0.2.13 weasel 0.3.4 weatherbench2 0.2.0 webencodings 0.5.1 Werkzeug 3.0.2 wheel 0.43.0 wrapt 1.16.0 xarray 2023.7.0 xarray-beam 0.6.3 xdoctest 1.0.2 xgboost 2.0.3 xyzservices 2024.6.0 yarl 1.9.4 zarr 2.17.2 zict 3.0.0 zipp 3.18.1 zstandard 0.22.0

jdwillard19 commented 4 months ago

Update: strangely, this error does not occur when I use the command-line script with the same data configuration:

    python evaluate.py \
      --forecast_path=/pscratch/sd/j/jwillard/healpix_era5/results/nat1d_1deg_e512_w513_lr5em4cos/00/inference/forecasts.zarr \
      --obs_path=/pscratch/sd/j/jwillard/healpix_era5/data/latlon_1deg_combined_wb.zarr \
      --output_dir=/pscratch/sd/j/jwillard/FCN_exp/wb2_eval/ \
      --output_file_prefix=test \
      --input_chunks=init_time=1,lead_time=1 \
      --runner=DirectRunner \
      --fanout=27 \
      --regions=all \
      --eval_configs=deterministic \
      --evaluate_climatology=False \
      --evaluate_persistence=False \
      --time_start=2020-01-01 \
      --time_stop=2022-12-31 \
      --pressure_level_suffixes=False \
      --variables=geopotential,temperature,u_component_of_wind,v_component_of_wind,specific_humidity,2m_temperature,10m_u_component_of_wind,10m_v_component_of_wind,mean_sea_level_pressure \
      --use_beam=True

jdwillard19 commented 3 months ago

Issue resolved.

My configuration file was passing the variables in as <class 'omegaconf.listconfig.ListConfig'> rather than a normal <class 'list'>, which is why xarray could not use the selection as a list of variable names. I didn't notice because the two print() and repr() identically. The issue was not "in Beam" vs. "out of Beam": the cases that worked (running in Jupyter, or the WB2 command-line script) simply bypassed my configuration loading. Sorry for the false alarm.
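
For anyone who hits the same thing: converting the OmegaConf containers to plain Python objects before building the Selection fixes it. A minimal sketch, where cfg stands in for my Hydra config and wb2_config is imported as in the snippet above:

    from omegaconf import OmegaConf

    # ListConfig/DictConfig -> plain list/dict before handing values to WB2.
    score_variables = OmegaConf.to_container(cfg.score_variables, resolve=True)
    score_levels = OmegaConf.to_container(cfg.score_levels, resolve=True)

    selection = wb2_config.Selection(
        variables=score_variables,
        levels=score_levels,
        time_slice=slice(cfg.score_start_date, cfg.score_end_date),
    )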