Closed: RogerJL closed this issue 11 months ago.
Workaround (convert to a plain tuple, then rebuild the NamedTuple with _make):
b_np = Observe._make(jax_to_numpy(tuple(a_jax)))
c_jax = Observe._make(numpy_to_jax(tuple(b_np)))
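The workaround succeeds because tuple(a_jax) strips the NamedTuple type. The generic Iterable conversion then only ever sees a plain tuple, and Observe._make restores the field names afterwards. Below is a minimal pure-Python sketch of that round trip that runs without JAX or NumPy. It assumes a hypothetical stand-in, convert_iterable, that mimics the "rebuild the container from one generator" pattern described in this report. The Observe fields here are also hypothetical.

```python
from typing import Any, Callable, Iterable, NamedTuple


class Observe(NamedTuple):
    # Hypothetical fields; the real Observe definition is not shown here.
    board: Any
    player: Any


def convert_iterable(value: Iterable[Any], leaf: Callable[[Any], Any]) -> Any:
    # Hypothetical stand-in for the generic Iterable branch: it rebuilds the
    # container by passing ONE generator to the container's constructor.
    return type(value)(leaf(v) for v in value)


obs = Observe(board=0, player=1)

# Passing the NamedTuple directly fails: the generator binds to `board`
# and `player` is reported as missing.
try:
    convert_iterable(obs, float)
    direct_failed = False
except TypeError:
    direct_failed = True

# Workaround: convert a plain tuple, then restore the type with _make.
restored = Observe._make(convert_iterable(tuple(obs), float))
print(direct_failed, restored)
```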
Thanks for reporting this. I'm surprised, as this code was originally based on Google's Brax project, which I thought used NamedTuples.
Could you also test with flax struct dataclasses to see whether they work?
Feel free to make a PR, and add a test that checks your solution works.
Hmm... it looks like there are multiple ways to write NamedTuples, and the one I found and used might not even be a "namedtuple"!
This works! Or does it...
from collections import namedtuple
import jax
from gymnasium.experimental.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
Observe2 = namedtuple(
'Observe2',
['board', 'player'],
)
a_jax = Observe2(
    board=jax.numpy.zeros(shape=(8, 8)),
    player=jax.numpy.zeros(8),
)
b_np = jax_to_numpy(a_jax)
c_jax = numpy_to_jax(b_np)
The NamedTuple style was copied from gymnasium/envs/tabular/blackjack.py:28, so it might still be a Gymnasium bug!
But now both styles fail again...
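Both styles failing is consistent with how Python builds them. The functional collections.namedtuple factory and the class-based typing.NamedTuple syntax both produce ordinary tuple subclasses, because typing.NamedTuple is built on collections.namedtuple. Their generated __new__ expects one positional argument per field, so neither accepts a single generator. A small comparison follows. PointA and PointB are illustrative names only.

```python
from collections import namedtuple
from typing import NamedTuple

# Functional spelling.
PointA = namedtuple("PointA", ["x", "y"])


# Class-based spelling with type annotations.
class PointB(NamedTuple):
    x: int
    y: int


results = {}
for cls in (PointA, PointB):
    try:
        cls(v for v in (3, 4))  # one positional argument for two fields
        results[cls.__name__] = "ok"
    except TypeError:
        results[cls.__name__] = "TypeError"
    # Both expose the same helpers, so _make works for either spelling.
    print(cls.__name__, issubclass(cls, tuple), cls._fields, cls._make([3, 4]))

print(results)
```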
A minimal test case showing the problem:
from collections import namedtuple
Circle2 = namedtuple(
'Circle2',
['center_x', 'center_y', 'radius']
)
a=Circle2(1,2,3)
b=type(a)(2,3,4)
c=type(a)(x for x in range(3))
Traceback (most recent call last):
File "C:\Program Files\JetBrains\PyCharm 2023.1.2\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
coro = func()
^^^^^^
File "<input>", line 1, in <module>
TypeError: Circle2.__new__() missing 2 required positional arguments: 'center_y' and 'radius'
c=type(a)(*(x for x in range(3)))
c
Circle2(center_x=0, center_y=1, radius=2)
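The session shows that a NamedTuple constructor takes one positional argument per field. An iterable therefore has to be unpacked with * or passed to the _make class method, whereas a plain tuple accepts a single iterable. A short side-by-side comparison, redefining the same Circle2 so the snippet is self-contained:

```python
from collections import namedtuple

Circle2 = namedtuple("Circle2", ["center_x", "center_y", "radius"])

values = range(3)

# Plain tuple: the constructor takes one iterable.
as_tuple = tuple(x for x in values)

# NamedTuple: unpack the iterable into positional arguments...
unpacked = Circle2(*(x for x in values))

# ...or use the _make class method, which accepts any iterable directly.
made = Circle2._make(x for x in values)

print(as_tuple, unpacked, made)
```

Either form works for rebuilding a NamedTuple. _make reads more directly when the input is already an iterable, and it raises a TypeError if the number of items does not match the number of fields.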
I tried to create a PR, but it fails:
"Can't create a new pull request: Push failed: remote: Permission to Farama-Foundation/Gymnasium.git denied to RogerJL. unable to access 'https://github.com/Farama-Foundation/Gymnasium.git/': The requested URL returned error: 403"
You will need to fork the repo, make a branch with your changes on that fork, and then open a PR against this Gymnasium repo.
@pseudo-rnd-thoughts I created a draft PR. I guess there should be some documentation changes too, but I can't find a good place for them. Should the PR be squashed into a single commit, and can that be done easily?
Describe the bug
The jax_to_numpy wrapper conversion fails for NamedTuples in
As there is no specific handling for NamedTuples, the generic Iterable implementation is used.
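To illustrate how a NamedTuple can end up in a generic Iterable branch: with functools.singledispatch, an implementation registered for collections.abc.Iterable matches any tuple subclass unless a more specific type is registered. The sketch below uses a hypothetical describe function, not Gymnasium's code. A dict still picks the Mapping implementation because Mapping is more specific than Iterable.

```python
from collections import abc, namedtuple
from functools import singledispatch


@singledispatch
def describe(value):
    # Fallback for leaves (numbers, arrays, ...).
    return "leaf"


@describe.register(abc.Mapping)
def _(value):
    return "mapping"


@describe.register(abc.Iterable)
def _(value):
    return "iterable"


Pair = namedtuple("Pair", ["a", "b"])

print(describe(1.0))         # float has no registration
print(describe({"a": 1}))    # dict matches the more specific Mapping
print(describe(Pair(1, 2)))  # NamedTuple has no registration of its own
```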
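That implementation tries to construct the NamedTuple positionally from a single generator, which fails because the constructor expects one argument per field. NamedTupleClass._make(iterable) should be used instead. I think it should look something like this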
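One way such a fix could look is sketched below in pure Python so it runs without JAX. Inside the Iterable implementation, it detects a NamedTuple (a tuple subclass whose class has _fields) and rebuilds it with _make; every other container keeps the existing one-generator behavior. The name convert and the float leaf conversion are hypothetical stand-ins for jax_to_numpy and its array conversion. This is not the change from the PR.

```python
from collections import abc, namedtuple
from functools import singledispatch
from typing import Any


@singledispatch
def convert(value: Any) -> Any:
    # Hypothetical leaf conversion standing in for the JAX -> NumPy array step.
    return float(value)


def _is_namedtuple(value: Any) -> bool:
    # Both NamedTuple spellings create tuple subclasses that expose _fields.
    return isinstance(value, tuple) and hasattr(type(value), "_fields")


@convert.register(abc.Iterable)
def _convert_iterable(value: abc.Iterable) -> abc.Iterable:
    # Strings are also Iterable and would recurse endlessly here;
    # a real version needs a separate case for them.
    converted = (convert(v) for v in value)
    if _is_namedtuple(value):
        # NamedTuple.__new__ needs one argument per field; _make takes an iterable.
        return type(value)._make(converted)
    # Other containers (list, tuple, ...) accept a single iterable.
    return type(value)(converted)


Observe = namedtuple("Observe", ["board", "player"])
result = convert(Observe(board=[0, 1], player=2))
print(result)
```

The same NamedTuple check would apply unchanged in the reverse direction, since only the leaf conversion differs.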
I guess numpy_to_jax needs the same fix too?
Code example
System info
Installed with: pip install gymnasium
Gymnasium 0.29.1
Windows 11 Pro (should not matter)
Python 3.11
Additional context
No response