google-deepmind / graphcast

Apache License 2.0
4.37k stars 538 forks source link

Error while forecasting using processed GFS data in Autoregressive rollout code #35

Closed HappyPolo closed 7 months ago

HappyPolo commented 7 months ago

I'm attempting to use processed GFS (Global Forecast System) data for forecasting. The processed data's coordinates, dimensions, and variables are nearly identical to the example data provided. However, when running the Autoregressive rollout code snippet, specifically this segment:

# Autoregressive rollout code snippet
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

I encounter an error. The error message is as follows.

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/rollout.py:69, in chunked_prediction(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose)
     67 #   print(inputs)
     68   chunks_list = []
---> 69   for prediction_chunk in chunked_prediction_generator(
     70       predictor_fn=predictor_fn,
     71       rng=rng,
     72       inputs=inputs,
     73       targets_template=targets_template,
     74       forcings=forcings,
     75       num_steps_per_chunk=num_steps_per_chunk,
     76       verbose=verbose):
     77     chunks_list.append(jax.device_get(prediction_chunk))
     78   return xarray.concat(chunks_list, dim="time")

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/rollout.py:165, in chunked_prediction_generator(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose)
    163 # Make predictions for the chunk.
    164 rng, this_rng = jax.random.split(rng)
--> 165 predictions = predictor_fn(
    166     rng=this_rng,
    167     inputs=current_inputs,
    168     targets_template=current_targets_template,
    169     forcings=current_forcings)
    171 next_frame = xarray.merge([predictions, current_forcings])
    173 current_inputs = _get_next_inputs(current_inputs, next_frame)

/mnt/d/kt/project/yl/google-graphcast/code/test_graphcast.ipynb 单元格 15 line 6
     64 def drop_state(fn):
---> 65   return lambda **kw: fn(**kw)[0]

    [... skipping hidden 12 frame]

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/transform.py:456, in transform_with_state.<locals>.apply_fn(params, state, rng, *args, **kwargs)
    454 with base.new_context(params=params, state=state, rng=rng) as ctx:
    455   try:
--> 456     out = f(*args, **kwargs)
    457   except jax.errors.UnexpectedTracerError as e:
    458     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

/mnt/d/kt/project/yl/google-graphcast/code/test_graphcast.ipynb 单元格 15 line 3
     27 @hk.transform_with_state
     28 def run_forward(model_config, task_config, inputs, targets_template, forcings):
     29   predictor = construct_wrapped_graphcast(model_config, task_config)
---> 30   return predictor(inputs, targets_template=targets_template, forcings=forcings)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/autoregressive.py:212, in Predictor.__call__(self, inputs, targets_template, forcings, **kwargs)
    209     one_step_prediction = hk.remat(one_step_prediction)
    211 # Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
--> 212 _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)
    214 # The result of scan will have an extra leading axis on all arrays,
    215 # corresponding to the target times in this case. We need to be prepared for
    216 # it when unflattening the arrays back into a Dataset:
    217 scan_result_template = (
    218     target_template.squeeze('time', drop=True)
    219     .expand_dims(time=targets_template.coords['time'], axis=0))

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/stateful.py:643, in scan(f, init, xs, length, reverse, unroll)
    637 # We know that we don't need to thread params in and out, since for init we
    638 # have already created them (given that above we unroll one step of the scan)
    639 # and for apply we know they are immutable. As such we only need to thread the
    640 # state and rng in and out.
    642 init = (init, internal_state(params=False))
--> 643 (carry, state), ys = jax.lax.scan(
    644     stateful_fun, init, xs, length, reverse, unroll=unroll)
    645 update_internal_state(state)
    647 if running_init_fn:

    [... skipping hidden 9 frame]

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/stateful.py:626, in scan.<locals>.stateful_fun(carry, x)
    623 with temporary_internal_state(state):
    624   with base.assert_no_new_parameters(), \
    625        base.push_jax_trace_level():
--> 626     carry, out = f(carry, x)
    627   reserve_up_to_full_rng_block()
    628   carry = (carry, internal_state(params=False))

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/autoregressive.py:183, in Predictor.__call__.<locals>.one_step_prediction(inputs, scan_variables)
    181 # Add constant inputs:
    182 all_inputs = xarray.merge([constant_inputs, inputs])
--> 183 predictions: xarray.Dataset = self._predictor(
    184     all_inputs, target_template,
    185     forcings=forcings,
    186     **kwargs)
    188 next_frame = xarray.merge([predictions, forcings])
    189 next_inputs = self._update_inputs(inputs, next_frame)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/normalization.py:156, in InputsAndResiduals.__call__(self, inputs, targets_template, forcings, **kwargs)
    154 norm_inputs = normalize(inputs, self._scales, self._locations)
    155 norm_forcings = normalize(forcings, self._scales, self._locations)
--> 156 norm_predictions = self._predictor(
    157     norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
    158 return xarray_tree.map_structure(
    159     lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
    160     norm_predictions)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/casting.py:56, in Bfloat16Cast.__call__(self, inputs, targets_template, forcings, **kwargs)
     52   return self._predictor(inputs, targets_template, forcings, **kwargs)
     54 with bfloat16_variable_view():
     55   predictions = self._predictor(
---> 56       *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
     57       **kwargs,)
     59 predictions_dtype = infer_floating_dtype(predictions)
     60 if predictions_dtype != jnp.bfloat16:

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/casting.py:179, in _all_inputs_to_bfloat16(inputs, targets, forcings)
    164 def _all_inputs_to_bfloat16(
    165     inputs: xarray.Dataset,
    166     targets: xarray.Dataset,
   (...)
    177 #   data_vars = {key: value.values if hasattr(value, 'values') else value for key, value in inputs.items()}
    178 #   dataset = xr.Dataset(data_vars)
--> 179     return (inputs.astype(jnp.bfloat16),
    180             jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
    181             forcings.astype(jnp.bfloat16))
AttributeError: 'dict' object has no attribute 'astype'

However, running the code with the example data works correctly. The error message above is the one I encounter specifically when forecasting with my processed GFS data. I would appreciate any assistance in resolving this issue.

Stoby200 commented 7 months ago

I'm attempting to use processed GFS (Global Forecast System) data for forecasting. The processed data's coordinates, dimensions, and variables are nearly identical to the example data provided. However, when running the Autoregressive rollout code snippet, specifically this segment:

# Autoregressive rollout code snippet
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

I encounter an error. The error message is as follows.

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/rollout.py:69, in chunked_prediction(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose)
     67 #   print(inputs)
     68   chunks_list = []
---> 69   for prediction_chunk in chunked_prediction_generator(
     70       predictor_fn=predictor_fn,
     71       rng=rng,
     72       inputs=inputs,
     73       targets_template=targets_template,
     74       forcings=forcings,
     75       num_steps_per_chunk=num_steps_per_chunk,
     76       verbose=verbose):
     77     chunks_list.append(jax.device_get(prediction_chunk))
     78   return xarray.concat(chunks_list, dim="time")

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/rollout.py:165, in chunked_prediction_generator(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose)
    163 # Make predictions for the chunk.
    164 rng, this_rng = jax.random.split(rng)
--> 165 predictions = predictor_fn(
    166     rng=this_rng,
    167     inputs=current_inputs,
    168     targets_template=current_targets_template,
    169     forcings=current_forcings)
    171 next_frame = xarray.merge([predictions, current_forcings])
    173 current_inputs = _get_next_inputs(current_inputs, next_frame)

/mnt/d/kt/project/yl/google-graphcast/code/test_graphcast.ipynb 单元格 15 line 6
     64 def drop_state(fn):
---> 65   return lambda **kw: fn(**kw)[0]

    [... skipping hidden 12 frame]

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/transform.py:456, in transform_with_state.<locals>.apply_fn(params, state, rng, *args, **kwargs)
    454 with base.new_context(params=params, state=state, rng=rng) as ctx:
    455   try:
--> 456     out = f(*args, **kwargs)
    457   except jax.errors.UnexpectedTracerError as e:
    458     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

/mnt/d/kt/project/yl/google-graphcast/code/test_graphcast.ipynb 单元格 15 line 3
     27 @hk.transform_with_state
     28 def run_forward(model_config, task_config, inputs, targets_template, forcings):
     29   predictor = construct_wrapped_graphcast(model_config, task_config)
---> 30   return predictor(inputs, targets_template=targets_template, forcings=forcings)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/autoregressive.py:212, in Predictor.__call__(self, inputs, targets_template, forcings, **kwargs)
    209     one_step_prediction = hk.remat(one_step_prediction)
    211 # Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
--> 212 _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)
    214 # The result of scan will have an extra leading axis on all arrays,
    215 # corresponding to the target times in this case. We need to be prepared for
    216 # it when unflattening the arrays back into a Dataset:
    217 scan_result_template = (
    218     target_template.squeeze('time', drop=True)
    219     .expand_dims(time=targets_template.coords['time'], axis=0))

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/stateful.py:643, in scan(f, init, xs, length, reverse, unroll)
    637 # We know that we don't need to thread params in and out, since for init we
    638 # have already created them (given that above we unroll one step of the scan)
    639 # and for apply we know they are immutable. As such we only need to thread the
    640 # state and rng in and out.
    642 init = (init, internal_state(params=False))
--> 643 (carry, state), ys = jax.lax.scan(
    644     stateful_fun, init, xs, length, reverse, unroll=unroll)
    645 update_internal_state(state)
    647 if running_init_fn:

    [... skipping hidden 9 frame]

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/stateful.py:626, in scan.<locals>.stateful_fun(carry, x)
    623 with temporary_internal_state(state):
    624   with base.assert_no_new_parameters(), \
    625        base.push_jax_trace_level():
--> 626     carry, out = f(carry, x)
    627   reserve_up_to_full_rng_block()
    628   carry = (carry, internal_state(params=False))

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/autoregressive.py:183, in Predictor.__call__.<locals>.one_step_prediction(inputs, scan_variables)
    181 # Add constant inputs:
    182 all_inputs = xarray.merge([constant_inputs, inputs])
--> 183 predictions: xarray.Dataset = self._predictor(
    184     all_inputs, target_template,
    185     forcings=forcings,
    186     **kwargs)
    188 next_frame = xarray.merge([predictions, forcings])
    189 next_inputs = self._update_inputs(inputs, next_frame)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/normalization.py:156, in InputsAndResiduals.__call__(self, inputs, targets_template, forcings, **kwargs)
    154 norm_inputs = normalize(inputs, self._scales, self._locations)
    155 norm_forcings = normalize(forcings, self._scales, self._locations)
--> 156 norm_predictions = self._predictor(
    157     norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
    158 return xarray_tree.map_structure(
    159     lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
    160     norm_predictions)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/casting.py:56, in Bfloat16Cast.__call__(self, inputs, targets_template, forcings, **kwargs)
     52   return self._predictor(inputs, targets_template, forcings, **kwargs)
     54 with bfloat16_variable_view():
     55   predictions = self._predictor(
---> 56       *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
     57       **kwargs,)
     59 predictions_dtype = infer_floating_dtype(predictions)
     60 if predictions_dtype != jnp.bfloat16:

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/casting.py:179, in _all_inputs_to_bfloat16(inputs, targets, forcings)
    164 def _all_inputs_to_bfloat16(
    165     inputs: xarray.Dataset,
    166     targets: xarray.Dataset,
   (...)
    177 #   data_vars = {key: value.values if hasattr(value, 'values') else value for key, value in inputs.items()}
    178 #   dataset = xr.Dataset(data_vars)
--> 179     return (inputs.astype(jnp.bfloat16),
    180             jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
    181             forcings.astype(jnp.bfloat16))
AttributeError: 'dict' object has no attribute 'astype'

However, running the code with the example data works correctly. The error message above is the one I encounter specifically when forecasting with my processed GFS data. I would appreciate any assistance in resolving this issue.

Could you share how you are processing your GFS data? I am also trying to initialise with the GFS and am having issues.

Aquila96 commented 7 months ago

Could you provide a full dump of the example_batch? We encountered no issues running the autoregressive rollout by replicating the exact structure of the example_batch

HappyPolo commented 7 months ago

> Could you provide a full dump of the example_batch? We encountered no issues running the autoregressive rollout by replicating the exact structure of the example_batch

Thank you for your response. The issue has been resolved now; it seems that there was an error in handling the input data.