I've found a potential issue in the _convert_inps method of the JAXAgent class. When len(devices) == 1, it seems necessary to add value = tree_map(np.asarray, value) before the jax.device_put call. This would avoid exceptions when using custom programs that are compatible with the gym framework.
Here's the recommended modification:
def _convert_inps(self, value, devices):
  if len(devices) == 1:
    value = tree_map(np.asarray, value)  # Addition for gym compatibility
    value = jax.device_put(value, devices[0])
  else:
    check = tree_map(lambda x: len(x) % len(devices) == 0, value)
    if not all(jax.tree_util.tree_leaves(check)):
      shapes = tree_map(lambda x: x.shape, value)
      raise ValueError(
          f'Batch must be divisible by {len(devices)} devices: {shapes}')
    # TODO: Avoid the reshape?
    value = tree_map(
        lambda x: x.reshape((len(devices), -1) + x.shape[1:]), value)
    shards = []
    for i in range(len(devices)):
      shards.append(tree_map(lambda x: x[i], value))
    value = jax.device_put_sharded(shards, devices)
  return value
Hi, thanks for pointing that out. I would do that conversion in an environment wrapper instead, since the embodied framework defines the environment API to use NumPy types for everything.
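For anyone landing here later, a minimal sketch of what such a wrapper could look like. The names NumpyOutputs, to_numpy, and ToyEnv are hypothetical and not part of embodied or gym; the sketch only assumes an environment object with reset() and step(action) methods whose return values may contain plain Python scalars or lists.

```python
import numpy as np


def to_numpy(value):
    """Recursively convert the leaves of dicts and tuples to NumPy arrays."""
    if isinstance(value, dict):
        return {k: to_numpy(v) for k, v in value.items()}
    if isinstance(value, tuple):
        return tuple(to_numpy(v) for v in value)
    # Lists, Python scalars, and existing arrays all become np.ndarray.
    return np.asarray(value)


class NumpyOutputs:
    """Hypothetical wrapper that converts environment outputs to NumPy."""

    def __init__(self, env):
        self._env = env

    def __getattr__(self, name):
        # Forward public attributes (spaces, render, close, ...) to the env.
        if name.startswith('_'):
            raise AttributeError(name)
        return getattr(self._env, name)

    def reset(self, *args, **kwargs):
        return to_numpy(self._env.reset(*args, **kwargs))

    def step(self, action):
        return to_numpy(self._env.step(action))


class ToyEnv:
    """Stand-in environment that returns plain Python values."""

    def reset(self):
        return {'state': [0.0, 0.0]}

    def step(self, action):
        # Old-style gym tuple: (obs, reward, done, info).
        return {'state': [1.0, 2.0]}, 1.0, False, {}


env = NumpyOutputs(ToyEnv())
obs = env.reset()
obs, reward, done, info = env.step(0)
```

With the conversion done at the environment boundary, the agent only ever sees NumPy values and _convert_inps can stay unchanged.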
Thank you for considering this issue.
Best,