NVIDIA / earth2studio

Open-source deep-learning framework for exploring, building and deploying AI weather/climate workflows.
https://nvidia.github.io/earth2studio/
Apache License 2.0

🐛[BUG]: Infinite recursion with batch coords if batch is missing from the coord system #71

Closed NickGeneva closed 2 months ago

NickGeneva commented 2 months ago

Version

main

On which installation method(s) does this occur?

Source

Describe the issue

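The code should check whether `batch` exists in the input coordinates and raise an informative error if it does not. Alternatively, `input_coords` should be a function that returns a copy of the coordinate dict, so callers cannot modify the model's own coordinate system.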
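A minimal sketch of both options follows. It uses a hypothetical `ToyModel` class and a hypothetical `check_batch` helper, and neither is part of earth2studio. `input_coords()` returns a shallow copy of the coordinate dict, and `check_batch` raises a `ValueError` that lists the dimensions actually present. Coordinate values are placeholder lists rather than the arrays a real model would use.

```python
from collections import OrderedDict


class ToyModel:
    """Hypothetical stand-in for a model such as SFNO (not earth2studio code)."""

    def __init__(self) -> None:
        # Coordinate system with a leading "batch" dimension; values are placeholders
        self._input_coords = OrderedDict(
            batch=[], lead_time=[0], lat=[90.0, -90.0], lon=[0.0, 180.0]
        )

    def input_coords(self) -> OrderedDict:
        # Option 2: return a (shallow) copy so callers cannot delete keys
        # from the model's own coordinate dict
        return OrderedDict(self._input_coords)


def check_batch(coords: OrderedDict) -> None:
    # Option 1 (hypothetical helper): fail fast with an informative message
    if "batch" not in coords:
        raise ValueError(
            "Input coordinate system has no 'batch' dimension; "
            f"got dimensions {list(coords)}"
        )


model = ToyModel()
in_coords = model.input_coords()
del in_coords["batch"]  # modifies only the caller's copy

print("batch" in model.input_coords())  # True: the model's dict is untouched

try:
    check_batch(in_coords)
except ValueError as err:
    print(err)
```

A shallow copy is enough to protect against deleting or adding keys. `copy.deepcopy` would also guard the coordinate values against in-place edits, at the cost of copying them on every call.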

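```python
from earth2studio.models.px import SFNO

model = SFNO(None, None, None)
# in_coords refers to the model's own coordinate dict, not a copy
in_coords = model.input_coords
# Deleting "batch" here also removes it from model.input_coords
del in_coords["batch"]

out_coords = model.output_coords(in_coords)
```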
```
  ...
  File "/code/earth2studio/earth2studio/models/batch.py", line 330, in _wrapper
    flatten_coords, batched_coords = self._compress_batch(model, input_coords)
  File "/code/earth2studio/earth2studio/models/batch.py", line 272, in _compress_batch
    and next(iter(model.output_coords(model.input_coords))) != "batch"
  ... (the same two frames repeat indefinitely)
```
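The call cycle in the traceback can be reproduced with a toy decorator. The toy assumes, as the frames suggest, that the wrapper probes `model.output_coords(model.input_coords)` whenever its input lacks `batch`. `batch_wrapper` and `ToyModel` are hypothetical, and the real condition in `_compress_batch` is not shown in the issue and may differ. In the toy, the probe terminates only if the model's own `input_coords` still contains `batch`. Deleting the key through an alias therefore makes the wrapper re-enter itself until Python raises `RecursionError`.

```python
from collections import OrderedDict
from functools import wraps


def batch_wrapper(func):
    """Hypothetical toy mimicking the _wrapper -> _compress_batch cycle."""

    @wraps(func)
    def _wrapper(model, input_coords):
        if "batch" not in input_coords:
            # Probe the model's own coordinate system. This re-enters _wrapper
            # and only stops if model.input_coords still contains "batch".
            model.output_coords(model.input_coords)
        return func(model, input_coords)

    return _wrapper


class ToyModel:
    def __init__(self) -> None:
        # Exposed as a plain attribute, so every caller shares this dict
        self.input_coords = OrderedDict(batch=[], lat=[0.0], lon=[0.0])

    @batch_wrapper
    def output_coords(self, input_coords):
        return OrderedDict(input_coords)


model = ToyModel()
in_coords = model.input_coords  # alias, not a copy
del in_coords["batch"]  # also removes "batch" from model.input_coords

recursed = False
try:
    model.output_coords(in_coords)
except RecursionError:
    recursed = True
print(recursed)  # True

# Copying before editing leaves the model's coordinates intact, so the probe terminates
fresh = ToyModel()
safe_coords = OrderedDict(fresh.input_coords)
del safe_coords["batch"]
print(list(fresh.output_coords(safe_coords)))  # ['lat', 'lon']
```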