sotetsuk / pgx

♟️ Vectorized RL game environments in JAX
http://sotets.uk/pgx/
Apache License 2.0
372 stars 23 forks source link

Loading Trained AlphaZero Model #1172

Closed sr5434 closed 4 months ago

sr5434 commented 5 months ago

Hey all! I trained AlphaZero on Kuhn Poker with the provided example. I am now trying to adapt the baseline loader to load my model, but it returns this error:

Traceback (most recent call last):
  File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 130, in <module>
    print(model(
          ^^^^^^
  File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 19, in apply
    (logits, value), _ = forward.apply(model_params, model_state, obs, is_eval=True)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/multi_transform.py", line 296, in apply_fn
    return f.apply(params, state, None, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 13, in forward_fn
    policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 82, in __call__
    logits = hk.Linear(self.num_actions)(logits)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 458, in wrapped
    out = f(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 299, in run_interceptors
    return bound_method(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/basic.py", line 179, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/base.py", line 685, in get_parameter
    raise ValueError(
ValueError: 'az_net/linear/w' with retrieved shape (14, 4) does not match shape=[2, 4] dtype=dtype('float32')

This is my code:

import pickle
from typing import NamedTuple

from pydantic import BaseModel
import jax
import jax.numpy as jnp
import haiku as hk
import pgx

def _make_az_baseline_model(model_args, model_params, model_state):
    def forward_fn(x, is_eval=False):
        net = _create_az_model_v0(num_actions=4, num_channels=model_args.num_channels, num_layers=model_args.num_layers)
        policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False)
        return policy_out, value_out

    forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))

    def apply(obs):
        (logits, value), _ = forward.apply(model_params, model_state, obs, is_eval=True)
        return logits, value

    return apply

def _create_az_model_v0(
        num_actions,
        num_channels: int = 128,
        num_layers: int = 6,
        resnet_v2: bool = True,
):
    class BlockV2(hk.Module):
        def __init__(self, num_channels, name="BlockV2"):
            super(BlockV2, self).__init__(name=name)
            self.num_channels = num_channels

        def __call__(self, x, is_training, test_local_stats):

            i = x
            x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
            x = jax.nn.relu(x)
            x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
            x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
            x = jax.nn.relu(x)
            x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
            return x + i

    class AZNet(hk.Module):
        """AlphaZero NN architecture."""

        def __init__(
                self,
                num_actions,
                num_channels: int,
                num_layers: int,
                resnet_v2=True,
                name="az_net",
        ):
            super().__init__(name=name)
            self.num_actions = num_actions
            self.num_channels = num_channels
            self.num_layers = num_layers
            self.resnet_v2 = True
            self.resnet_cls = BlockV2

        def __call__(self, x, is_training=False, test_local_stats=False):
            x = x.reshape((x.shape[0], x.shape[1], 1))
            x = x.astype(jnp.float32)
            x = hk.Conv2D(self.num_channels, kernel_shape=2)(x)

            for i in range(self.num_layers):
                x = self.resnet_cls(self.num_channels, name=f"block_{i}")(
                    x, is_training, test_local_stats
                )
            x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
            x = jax.nn.relu(x)

            # policy head
            logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x)
            logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats)
            logits = jax.nn.relu(logits)
            logits = hk.Flatten()(logits)
            logits = hk.Linear(self.num_actions)(logits)

            # value head
            v = hk.Conv2D(output_channels=1, kernel_shape=1)(x)
            v = hk.BatchNorm(True, True, 0.9)(v, is_training, test_local_stats)
            v = jax.nn.relu(v)
            v = hk.Flatten()(v)
            v = hk.Linear(self.num_channels)(v)
            v = jax.nn.relu(v)
            v = hk.Linear(1)(v)
            v = jnp.tanh(v)
            v = v.reshape((-1,))

            return logits, v

    return AZNet(num_actions, num_channels, num_layers, resnet_v2)

class Config(BaseModel):
    env_id: pgx.EnvId = "kuhn_poker"
    seed: int = 0
    max_num_iters: int = 50000
    # network params
    num_channels: int = 128
    num_layers: int = 6
    resnet_v2: bool = True
    # selfplay params
    selfplay_batch_size: int = 1
    num_simulations: int = 32
    max_num_steps: int = 256
    # training params
    training_batch_size: int = 4096
    learning_rate: float = 0.001
    # eval params
    eval_interval: int = 5

class Sample(NamedTuple):
    obs: jnp.ndarray
    policy_tgt: jnp.ndarray
    value_tgt: jnp.ndarray
    mask: jnp.ndarray
config: Config = Config()
env = pgx.make(config.env_id)

if __name__ == "__main__":
    with open("/Users/samir/Documents/Apps/alphaZero/model.ckpt", "rb") as f:
        d = pickle.load(f)
    model = _make_az_baseline_model(d["config"], d["model"][0], d["model"][1])
    print(model(
        jnp.array([[0.],
         [0.],
         [1.],
         [1.],
         [0.],
         [1.],
         [0.]])
    ))

How do I fix this?

sr5434 commented 5 months ago

@sotetsuk do you know what is causing this?

sotetsuk commented 4 months ago

Hi, thank you for your comment and very sorry for the late response 🙏 Sorry but we do not provide any support for training issues by users 🙏 😭 It looks just a simple shape problems, you may check the layer output step by step.