Farama-Foundation / Gymnasium

An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
https://gymnasium.farama.org
MIT License
7.32k stars, 817 forks

[Bug Report] TypeError: <NamedTuple>.__new__() missing 1 required positional argument: '<argument>' #780

Closed: RogerJL closed this issue 11 months ago

RogerJL commented 11 months ago

Describe the bug

The wrapper conversion function jax_to_numpy fails for NamedTuples in

As there is no specific handling registered for NamedTuples, the abc.Iterable handler is used:

@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
    value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
    """Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
    return type(value)(jax_to_numpy(v) for v in value)

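This passes the generator as a single positional argument to the NamedTuple constructor, which fails because a NamedTuple's __new__ expects one positional argument per field; NamedTupleClass._make(iterable) should be used instead. I think it should look something like this: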

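@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
    value: Iterable[jax.Array | Any],
) -> Iterable[np.ndarray | Any]:
    """Converts an Iterable of Jax Arrays to an iterable of NumPy arrays."""
    if hasattr(value, "_make"):
        # NamedTuple: __new__ expects one positional argument per field, so rebuild
        # it with _make. (typing.NamedTuple cannot be passed to register(): the
        # classes it creates do not inherit from it.)
        return type(value)._make(jax_to_numpy(v) for v in value)
    return type(value)(jax_to_numpy(v) for v in value)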

I guess there should be a numpy_to_jax counterpart too?
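
To check the _make idea without installing JAX, the same singledispatch layout can be exercised with a stand-in leaf conversion. This is a minimal sketch, not Gymnasium's code: to_float and its int-to-float rule are hypothetical stand-ins for jax_to_numpy and its jax.Array handler. The namedtuple is detected by duck-typing on _make, because namedtuple classes do not actually inherit from typing.NamedTuple, so registering that name would never match them. A numpy_to_jax counterpart would presumably need the identical branch.

```python
from collections import abc, namedtuple
from functools import singledispatch
from typing import Any


@singledispatch
def to_float(value: Any) -> Any:
    """Fallback for unregistered types: return the value unchanged."""
    return value


@to_float.register(int)
def _int_to_float(value: int) -> float:
    # Leaf conversion; a hypothetical stand-in for jax.Array -> np.ndarray.
    return float(value)


@to_float.register(str)
def _str_to_float(value: str) -> str:
    # Strings are Iterable too, so stop here to avoid endless recursion.
    return value


@to_float.register(abc.Iterable)
def _iterable_to_float(value: abc.Iterable) -> abc.Iterable:
    converted = (to_float(v) for v in value)
    if hasattr(value, "_make"):
        # namedtuple (collections.namedtuple or typing.NamedTuple): __new__
        # takes one argument per field, so rebuild from the iterable via _make.
        return type(value)._make(converted)
    # Plain containers (tuple, list, ...) accept a single iterable argument.
    # Mappings would need their own handler; they are omitted in this sketch.
    return type(value)(converted)


Point = namedtuple("Point", ["x", "y"])
print(to_float([Point(1, 2), (3, "a")]))  # [Point(x=1.0, y=2.0), (3.0, 'a')]
```

The duck-typed check also matches any unrelated class that happens to define _make; a stricter test such as isinstance(value, tuple) and hasattr(value, "_fields") is a possible alternative.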

Code example

from typing import NamedTuple

import jax
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax

class Observe(NamedTuple):
    board: jax.Array
    player: jax.Array

a_jax = Observe(board=jax.numpy.zeros(shape=(8, 8)),
                player=jax.numpy.zeros(8))

b_np = jax_to_numpy(a_jax)
c_jax = numpy_to_jax(b_np)

System info

Installed with pip install gymnasium (Gymnasium 0.29.1), running on Windows 11 Pro (should not matter) with Python 3.11.

RogerJL commented 11 months ago

Workaround

b_np = Observe._make(jax_to_numpy(tuple(a_jax)))
c_jax = Observe._make(numpy_to_jax(tuple(b_np)))
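
The workaround can be factored into a small helper so it is written once for both directions. This is a sketch: convert_namedtuple is a hypothetical name, and convert is assumed to be a function such as jax_to_numpy or numpy_to_jax that maps a plain tuple to a plain tuple.

```python
from collections import namedtuple
from typing import Any, Callable, Iterable, TypeVar

NT = TypeVar("NT", bound=tuple)


def convert_namedtuple(convert: Callable[[tuple], Iterable[Any]], value: NT) -> NT:
    """Apply `convert` to a namedtuple's fields and rebuild the same class.

    Mirrors the workaround: hand `convert` a plain tuple (which the Iterable
    handler rebuilds without trouble), then restore the namedtuple via _make.
    Only the top level is handled; nested namedtuples are not rebuilt here.
    """
    return type(value)._make(convert(tuple(value)))


# Stand-alone demo with a hypothetical converter in place of jax_to_numpy.
Pair = namedtuple("Pair", ["left", "right"])
doubled = convert_namedtuple(lambda t: tuple(2 * x for x in t), Pair(1, 2))
print(doubled)  # Pair(left=2, right=4)
```

With the issue's code this would read b_np = convert_namedtuple(jax_to_numpy, a_jax) and c_jax = convert_namedtuple(numpy_to_jax, b_np). Fields that are themselves namedtuples would still hit the original bug.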

pseudo-rnd-thoughts commented 11 months ago

Thanks for reporting that. I'm surprised, as this code was originally based on Google's Brax project, which I thought used NamedTuples. Could you test with Flax structs as well to see if they work?

Feel free to make a PR and add a test to check that your solution works.

RogerJL commented 11 months ago

Hmm... it looks like there are multiple ways to write NamedTuples; the one I found and used might not even be a "namedtuple"!

This works! Or does it...

from collections import namedtuple

import jax
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax

Observe2 = namedtuple(
    'Observe2',
    ['board', 'player'],
)

a_jax = Observe2(board=jax.numpy.zeros(shape=(8, 8)),
                 player=jax.numpy.zeros(8)
)

b_np = jax_to_numpy(a_jax)
c_jax = numpy_to_jax(b_np)

The NamedTuple style was copied from gymnasium/envs/tabular/blackjack.py:28, so it might still be a Gymnasium bug!

But now both styles fail again...
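
As far as the conversion code is concerned, both spellings produce the same kind of class, which is consistent with both failing: each is a plain tuple subclass, so both reach the abc.Iterable handler and its type(value)(generator) call. A short stdlib check (the class names are illustrative, mirroring Observe and Observe2):

```python
from collections import namedtuple
from typing import NamedTuple


class ObserveTyped(NamedTuple):  # class syntax, the style used for Observe
    board: list
    player: int


ObserveClassic = namedtuple("ObserveClassic", ["board", "player"])  # factory style

for cls in (ObserveTyped, ObserveClassic):
    # Both are ordinary tuple subclasses carrying the namedtuple helpers...
    assert issubclass(cls, tuple)
    assert cls._fields == ("board", "player")
    assert cls._make(iter([[0], 1])) == cls(board=[0], player=1)
    # ...and neither has typing.NamedTuple in its MRO, so registering that
    # name with singledispatch would not match these classes.
    assert NamedTuple not in cls.__mro__
```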

RogerJL commented 11 months ago

A minimal test case showing the problem:

>>> from collections import namedtuple
>>> Circle2 = namedtuple('Circle2', ['center_x', 'center_y', 'radius'])
>>> a = Circle2(1, 2, 3)
>>> b = type(a)(2, 3, 4)
>>> c = type(a)(x for x in range(3))
Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2023.1.2\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
TypeError: Circle2.__new__() missing 2 required positional arguments: 'center_y' and 'radius'
>>> c = type(a)(*(x for x in range(3)))
>>> c
Circle2(center_x=0, center_y=1, radius=2)
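
The root cause is visible here: tuple and list accept one iterable argument, while a namedtuple's __new__ takes one positional argument per field. Besides star-unpacking, the _make classmethod suggested at the top of the issue is the standard way to build a namedtuple from an iterable. A quick comparison, using only the standard library:

```python
from collections import namedtuple

Circle2 = namedtuple("Circle2", ["center_x", "center_y", "radius"])

# Plain tuples (and lists) accept a single iterable argument...
assert tuple(x for x in range(3)) == (0, 1, 2)

# ...but a lone generator only fills `center_x`, so the namedtuple call fails.
try:
    Circle2(x for x in range(3))
except TypeError as err:
    error_message = str(err)
else:
    raise AssertionError("expected a TypeError")

# Two ways to build the namedtuple from an iterable:
c1 = Circle2(*(x for x in range(3)))     # unpack into positional arguments
c2 = Circle2._make(x for x in range(3))  # classmethod intended for this case
assert c1 == c2 == Circle2(center_x=0, center_y=1, radius=2)
```

_make avoids building an argument list and keeps the "rebuild from an iterable" intent explicit, which is why it fits a generic converter better than unpacking.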

RogerJL commented 11 months ago

Tried to create a PR, but it fails:

"Can't create a new pull request: Push failed: remote: Permission to Farama-Foundation/Gymnasium.git denied to RogerJL. unable to access 'https://github.com/Farama-Foundation/Gymnasium.git/': The requested URL returned error: 403"

pseudo-rnd-thoughts commented 11 months ago

You will need to fork the repo, make a branch with your changes on that fork, and then open a PR against this Gymnasium repo.

RogerJL commented 11 months ago

@pseudo-rnd-thoughts I created a draft PR. I guess there should be some documentation changes too, but I can't find a good place for them. Should the PR be squashed into a single commit, and can that be done easily?