danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Addition of "tree_map(np.asarray, value)" to "_convert_inps" function in "JAXAgent" #89

Closed: ExuberantWitness closed this issue 11 months ago.

ExuberantWitness commented 1 year ago

Hello,

I've identified a potential issue in the _convert_inps method of the JAXAgent class. When len(devices) == 1, the inputs are passed to jax.device_put without any conversion, so I believe the line value = tree_map(np.asarray, value) should be added before that call. This conversion could prevent exceptions when using custom Gym-compatible environments whose outputs are not already NumPy arrays.

Here's the recommended modification:

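import jax
import numpy as np

tree_map = jax.tree_util.tree_map

def _convert_inps(self, value, devices):
  if len(devices) == 1:
    # Proposed addition: convert every leaf to a NumPy array (for
    # Gym-compatible environments) before placing it on the device.
    value = tree_map(np.asarray, value)
    value = jax.device_put(value, devices[0])
  else:
    # The leading (batch) dimension must split evenly across devices.
    check = tree_map(lambda x: len(x) % len(devices) == 0, value)
    if not all(jax.tree_util.tree_leaves(check)):
      shapes = tree_map(lambda x: x.shape, value)
      raise ValueError(
          f'Batch must be divisible by {len(devices)} devices: {shapes}')
    # TODO: Avoid the reshape?
    # Reshape (B, ...) -> (num_devices, B // num_devices, ...).
    value = tree_map(
        lambda x: x.reshape((len(devices), -1) + x.shape[1:]), value)
    # Build one shard per device from the i-th slice of every leaf.
    shards = []
    for i in range(len(devices)):
      shards.append(tree_map(lambda x: x[i], value))
    value = jax.device_put_sharded(shards, devices)
  return value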
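To make the multi-device branch above concrete, here is a NumPy-only sketch of its batch-splitting logic. It is a simplification under stated assumptions: split_batch is a hypothetical helper, the input is a flat dict rather than an arbitrary pytree, the device count is a plain integer, and the final jax.device_put_sharded call is left out. It applies np.asarray first, in the spirit of the proposed fix.

```python
import numpy as np


def split_batch(value, num_devices):
    """Split a flat dict of batched arrays into one dict per device.

    Hypothetical NumPy-only sketch of the multi-device branch of
    `_convert_inps`; the `jax.device_put_sharded` step is omitted.
    """
    # Mirror the proposed fix: make sure every value is a NumPy array.
    value = {k: np.asarray(v) for k, v in value.items()}
    # Every leading (batch) dimension must split evenly across devices.
    if not all(len(v) % num_devices == 0 for v in value.values()):
        shapes = {k: v.shape for k, v in value.items()}
        raise ValueError(
            f'Batch must be divisible by {num_devices} devices: {shapes}')
    # Reshape (B, ...) -> (num_devices, B // num_devices, ...).
    value = {
        k: v.reshape((num_devices, -1) + v.shape[1:])
        for k, v in value.items()}
    # The i-th shard holds the i-th slice of every array.
    return [{k: v[i] for k, v in value.items()} for i in range(num_devices)]


batch = {'image': np.zeros((4, 8, 8, 3)), 'reward': [0.0, 1.0, 0.5, 0.0]}
shards = split_batch(batch, num_devices=2)
print(len(shards), shards[0]['image'].shape, shards[1]['reward'])
```

Without the np.asarray step, the list-valued 'reward' entry would fail at .reshape, since Python lists have no such method.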

Thank you for considering this issue.

Best,

danijar commented 11 months ago

Hi, thanks for pointing that out. I would do that conversion in an environment wrapper instead. The embodied framework defines the environment API to use NumPy types for everything.
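The wrapper approach suggested above can be sketched as follows. This is a minimal, hypothetical example, not the embodied framework's own wrapper class: the names NumpyObs and ListEnv are illustrative, and it assumes the wrapped environment's step(action) returns a flat dict of observations.

```python
import numpy as np


class NumpyObs:
    """Hypothetical wrapper that converts every observation value to NumPy.

    Assumes the wrapped env's `step(action)` returns a flat dict.
    """

    def __init__(self, env):
        self.env = env

    def __getattr__(self, name):
        # Guard dunder lookups (e.g. during copying) to avoid recursion.
        if name.startswith('__'):
            raise AttributeError(name)
        # Forward everything else (spaces, close, ...) to the wrapped env.
        return getattr(self.env, name)

    def step(self, action):
        obs = self.env.step(action)
        # np.asarray turns Python lists into arrays and scalars into 0-d arrays.
        return {k: np.asarray(v) for k, v in obs.items()}


class ListEnv:
    """Hypothetical Gym-style env that returns plain Python values."""

    obs_space = {'vector': (3,)}

    def step(self, action):
        return {'vector': [0.1, 0.2, 0.3], 'reward': 1.0, 'is_first': False}


env = NumpyObs(ListEnv())
obs = env.step(0)
print({k: (v.dtype, v.shape) for k, v in obs.items()})
```

Converting each top-level value, rather than calling tree_map(np.asarray, obs), matters for list-valued observations: JAX's tree utilities treat Python lists as pytree nodes, so tree_map(np.asarray, [0.1, 0.2]) yields a list of 0-d arrays instead of one 1-D array. A real wrapper might also cast each value to the dtype declared in the environment's observation space.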