jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

How to make a replay_buffer object into a type that supports JAX data #15093

Open dengdengan opened 1 year ago

dengdengan commented 1 year ago

I create a replay buffer:

```python
replay_buffer = ReplayBuffer(env.observation_space, env.action_space,
                             FLAGS.max_steps)
```

I want to pass it as an incoming parameter to the function:

```python
agent, update_info = agent.update(batch, FLAGS.utd_ratio, replay_buffer)
```

But the update function requires jit acceleration:

```python
@partial(jax.jit, static_argnames='utd_ratio')
def update(self, batch: DatasetDict, utd_ratio: int, replay_buffer):

    new_agent = self
    for i in range(utd_ratio):

        def slice(x):
            assert x.shape[0] % utd_ratio == 0
            batch_size = x.shape[0] // utd_ratio
            return x[batch_size * i:batch_size * (i + 1)]

        mini_batch = jax.tree_util.tree_map(slice, batch)
        print(replay_buffer)
        new_agent, critic_info = self.update_critic(new_agent, mini_batch, replay_buffer)

        pro = critic_info['prios']
        pro1 = jnp.array([pro])
        indices = jnp.array(batch['indices'])
        replay_buffer.update_priorities(batch_indices=indices, batch_priorities=pro1)

    new_agent, actor_info = self.update_actor(new_agent, mini_batch)
    new_agent, temp_info = self.update_temperature(new_agent,
                                                   actor_info['entropy'])

    return new_agent, {**actor_info, **critic_info, **temp_info}
```

When I run it directly, the following error is reported:

```
TypeError: Argument '<rl.data.replay_buffer.ReplayBuffer object at 0x7f0fd441ef70>' of type <class 'rl.data.replay_buffer.ReplayBuffer'> is not a valid JAX type
```
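For context, this error is not specific to ReplayBuffer: `jax.jit` only accepts arguments whose pytree leaves are arrays (or scalars), and an arbitrary Python object is rejected at trace time. A stripped-down sketch (with a hypothetical `Foo` class standing in for the buffer) reproduces the same failure:

```python
import jax

class Foo:
    """A plain Python object; jit has no rule for flattening or tracing it."""
    pass

@jax.jit
def f(x, obj):
    return x + 1

f(1.0, Foo())  # TypeError: Argument ... of type <class 'Foo'> is not a valid JAX type
```

Two common ways out are (a) registering the container as a pytree so its arrays become traceable leaves, or (b) keeping host-side state like a replay buffer outside the jitted function entirely; a sketch of each follows the ReplayBuffer class below.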

Thank you very much for your help!

jakevdp commented 1 year ago

Hi, thanks for the question! It's hard to tell how to help without more information. Could you edit your question to add a minimal reproducible example? Package imports are important: for example I have no idea what ReplayBuffer is or what package it might be coming from. Thanks!

dengdengan commented 1 year ago

ReplayBuffer is a class used for experience replay in reinforcement learning, and replay_buffer is an instance of it. I want to pass this replay_buffer object as an argument to a jit-accelerated update function, but it is not a valid JAX type. How can I convert the replay_buffer to a data format that JAX supports? Thank you for your help. Here is the general content of the ReplayBuffer class:

```python
import gym
import numpy as np
from typing import Optional
from flax.core import frozen_dict

# Dataset, DatasetDict, and _init_replay_dict come from the project's own rl.data package.

class ReplayBuffer(Dataset):

    def __init__(self,
                 observation_space: gym.Space,
                 action_space: gym.Space,
                 capacity: int,
                 alpha: float = 0.6,
                 beta_start: float = 0.4,
                 beta_frames: float = 1e6,
                 next_observation_space: Optional[gym.Space] = None):
        if next_observation_space is None:
            next_observation_space = observation_space

        observation_data = _init_replay_dict(observation_space, capacity)
        next_observation_data = _init_replay_dict(next_observation_space,
                                                  capacity)
        dataset_dict = dict(
            observations=observation_data,
            next_observations=next_observation_data,
            actions=np.empty((capacity, *action_space.shape),
                             dtype=action_space.dtype),
            rewards=np.empty((capacity,), dtype=np.float32),
            dones=np.empty((capacity,), dtype=bool),
        )

        super().__init__(dataset_dict)

        # self._size = 0
        self._capacity = capacity
        self._insert_index = 0
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1
        self.capacity = capacity
        self.buffer = []
        self.pos = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def beta_by_frames(self, frame_idx):
        # linearly anneal beta from beta_start to 1.0 over beta_frames steps
        return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)

    def insert(self, data_dict: DatasetDict):
        state = data_dict['observations']
        next_state = data_dict['next_observations']
        assert state.ndim == next_state.ndim
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
        # new transitions get the current max priority (or 1.0 if the buffer is empty)
        max_prio = self.priorities.max() if self.buffer else 1.0
        transition = (data_dict['observations'], data_dict['actions'],
                      data_dict['rewards'], data_dict['next_observations'],
                      data_dict['dones'])
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition

        self.priorities[self.pos] = max_prio
        # wrap the write position so the buffer is circular
        self.pos = (self.pos + 1) % self.capacity

    def __len__(self) -> int:
        return len(self.buffer)

    def update_priorities(self, batch_indices, batch_priorities):
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = abs(prio)

    def sample(self, batch_size) -> frozen_dict.FrozenDict:
        N = len(self.buffer)
        if N == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]

        # P(i) = p_i^alpha / sum_k p_k^alpha
        probs = prios ** self.alpha
        P = probs / probs.sum()

        # draw indices according to the priority distribution P
        indices = np.random.choice(N, batch_size, p=P)
        samples = [self.buffer[idx] for idx in indices]

        beta = self.beta_by_frames(self.frame)
        self.frame += 1

        # importance-sampling weights, normalized by the largest weight
        weights = (N * P[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)

        observations, actions, rewards, next_observations, dones = zip(*samples)
        batch = {'observations': observations,
                 'actions': actions,
                 'rewards': rewards,
                 'next_observations': next_observations,
                 'dones': dones,
                 'indices': indices,
                 'weights': weights}
        # print(batch['weights'])
        for k, v in batch.items():
            batch[k] = np.array(v)
        return frozen_dict.freeze(batch)
```
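To answer the question as asked: a custom class can be made into "a type that supports JAX data" by registering it as a pytree, so that `jit` flattens it into array leaves. Below is a minimal sketch using `jax.tree_util.register_pytree_node`, with a hypothetical, stripped-down `SimpleBuffer` (only a priorities array plus a static capacity, not the full ReplayBuffer). The important caveat: traced code cannot mutate state in place, so a method like `update_priorities` has to become functional and return a new buffer.

```python
import jax
import jax.numpy as jnp

class SimpleBuffer:
    def __init__(self, priorities, capacity):
        self.priorities = priorities  # jnp.ndarray: a dynamic leaf, traced by jit
        self.capacity = capacity      # plain int: static auxiliary data

def _flatten(buf):
    # flatten returns (dynamic children, static aux data); aux data must be hashable
    return (buf.priorities,), (buf.capacity,)

def _unflatten(aux_data, children):
    return SimpleBuffer(children[0], aux_data[0])

jax.tree_util.register_pytree_node(SimpleBuffer, _flatten, _unflatten)

@jax.jit
def set_priorities(buf, indices, prios):
    # functional update: build a new buffer instead of assigning in place
    new_prios = buf.priorities.at[indices].set(jnp.abs(prios))
    return SimpleBuffer(new_prios, buf.capacity)

buf = SimpleBuffer(jnp.zeros(8), capacity=8)
buf = set_priorities(buf, jnp.array([1, 3]), jnp.array([0.5, -2.0]))
```

Since this ReplayBuffer holds a Python list (`self.buffer`) and mutates NumPy arrays in place (`insert`, `update_priorities`, and the `np.random.choice` call in `sample`), an often simpler pattern is to keep it out of `jit` entirely: drop the `replay_buffer` argument from `update`, return the priorities and indices in the info dict (the code above already puts `'prios'` in `critic_info` and `'indices'` in the batch), and update the buffer on the host after the jitted call returns, e.g.:

```python
# sketch, assuming update() is modified to drop the replay_buffer argument
agent, info = agent.update(batch, FLAGS.utd_ratio)
replay_buffer.update_priorities(batch_indices=np.asarray(batch['indices']),
                                batch_priorities=np.asarray(info['prios']))
```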