google-deepmind / graphcast

Apache License 2.0

Haiku needs all `hk.Module` must be initialized inside an `hk.transform` #53

Closed EloyAnguiano closed 5 months ago

EloyAnguiano commented 5 months ago

Hi. I am trying to execute the GraphCast model in a conda environment built with the same package versions as a working execution on Google Colab, but whenever I try to build the model, the `construct_wrapped_graphcast` function returns this error:

Traceback (most recent call last):
  File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 76, in <module>
    model = construct_wrapped_graphcast(model_config, task_config)
  File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 58, in construct_wrapped_graphcast
    predictor = graphcast.GraphCast(model_config, task_config)
  File "/home/eloy.anguiano/repos/graphcast/graphcast/graphcast.py", line 261, in __init__
    self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
  File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 139, in __call__
    init(module, *args, **kwargs)
  File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 433, in wrapped
    raise ValueError(
ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

I checked that both dm-haiku versions (Colab and local) are 0.0.11. Is there a Dockerfile to build a working environment, or something similar? It is very difficult to reproduce the Colab environment locally.

How to reproduce:

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import graphcast
from graphcast import normalization
import xarray

MODEL_VERSION = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'

# @title Authenticate with Google Cloud Storage
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

with gcs_bucket.blob(f"params/{MODEL_VERSION}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
params = ckpt.params
state = {}

model_config = ckpt.model_config
task_config = ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")

with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
    """Constructs and wraps the GraphCast Predictor."""
    # Deeper one-step predictor.
    predictor = graphcast.GraphCast(model_config, task_config)

    # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion
    # to/from float32 and BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
    # BFloat16 happens after applying normalization to the inputs/targets.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)

    # Wraps everything so the one-step model can produce trajectories.
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
    return predictor

model = construct_wrapped_graphcast(model_config, task_config)
print("Done")
alvarosg commented 5 months ago

Thanks for your message. This is expected behavior in Haiku: as the error says, all `hk.Module`s must be initialized inside an `hk.transform`, and `graphcast.GraphCast` is a Haiku module.

See this bit of code in the GraphCast demo:

@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)

which contains an example of how to use `hk.transform`.