Open dengdengan opened 1 year ago
Hi, thanks for the question! It's hard to tell how to help without more information. Could you edit your question to add a minimal reproducible example? Package imports are important: for example I have no idea what ReplayBuffer
is or what package it might be coming from. Thanks!
ReplayBuffer is a class used for experience replay in reinforcement learning, and I create an instance of it named replay_buffer. I want to pass this replay_buffer object as an input argument to a jit-compiled update function, but the object is not in a data format that JAX supports. How can I convert replay_buffer into a data format that JAX supports? Thank you for your help. Here is the general content of the ReplayBuffer class.
class ReplayBuffer(Dataset):
    """Prioritized experience replay buffer.

    Transitions are kept in a Python list (``self.buffer``) that is used as a
    ring buffer holding at most ``capacity`` entries, alongside a NumPy array
    of per-slot priorities. ``sample`` draws slot ``i`` with probability
    ``p_i ** alpha / sum(p ** alpha)`` and returns importance-sampling weights
    whose exponent ``beta`` is annealed linearly from ``beta_start`` to 1.0.

    NOTE(review): all state here is mutable Python/NumPy state and sampling
    uses ``np.random``, so an instance is presumably not meant to be passed
    into a jit-compiled function; pass the batch returned by ``sample``
    instead -- confirm against the training loop.
    """

    def __init__(self,
                 observation_space: gym.Space,
                 action_space: gym.Space,
                 capacity: int,
                 alpha: float = 0.6,
                 beta_start: float = 0.4,
                 beta_frames: float = 1e6,
                 next_observation_space: Optional[gym.Space] = None):
        """Allocate storage for up to ``capacity`` transitions.

        Args:
            observation_space: Space describing stored observations.
            action_space: Space describing stored actions.
            capacity: Maximum number of transitions kept in the buffer.
            alpha: Exponent applied to priorities when building the
                sampling distribution.
            beta_start: Initial importance-sampling exponent.
            beta_frames: Number of ``sample`` calls over which beta is
                annealed from ``beta_start`` to 1.0.
            next_observation_space: Space for next observations; defaults to
                ``observation_space`` when ``None``.
        """
        if next_observation_space is None:
            next_observation_space = observation_space

        observation_data = _init_replay_dict(observation_space, capacity)
        next_observation_data = _init_replay_dict(next_observation_space,
                                                  capacity)
        dataset_dict = dict(
            observations=observation_data,
            next_observations=next_observation_data,
            actions=np.empty((capacity, *action_space.shape),
                             dtype=action_space.dtype),
            rewards=np.empty((capacity, ), dtype=np.float32),
            dones=np.empty((capacity, ), dtype=bool),
        )
        super().__init__(dataset_dict)

        self._capacity = capacity
        self._insert_index = 0
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1  # number of the next sample() call, drives beta
        self.capacity = capacity
        self.buffer = []  # stored transitions, at most `capacity` of them
        self.pos = 0  # slot that the next insert() writes to
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def beta_by_frames(self, frame_idx):
        """Return beta for ``frame_idx``: linear ramp from beta_start, capped at 1.0."""
        return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)

    def insert(self, data_dict: DatasetDict):
        """Store one transition, overwriting the oldest slot once full.

        ``data_dict`` must provide the keys ``observations``, ``actions``,
        ``rewards``, ``next_observations`` and ``dones``. The new transition
        receives the current maximum priority (1.0 for an empty buffer).
        """
        state = data_dict['observations']
        next_state = data_dict['next_observations']
        assert state.ndim == next_state.ndim

        # Read the max priority before this slot's priority is overwritten.
        max_prio = self.priorities.max() if self.buffer else 1.0

        transition = (data_dict['observations'],
                      data_dict['actions'],
                      data_dict['rewards'],
                      data_dict['next_observations'],
                      data_dict['dones'])
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition

        self.priorities[self.pos] = max_prio
        # Wrap around so the oldest entry is replaced next.
        self.pos = (self.pos + 1) % self.capacity

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)

    def update_priorities(self, batch_indices, batch_priorities):
        """Set the priority of each given slot to the absolute value supplied."""
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = abs(prio)

    def sample(self, batch_size) -> frozen_dict.FrozenDict:
        """Draw ``batch_size`` transitions according to their priorities.

        Returns:
            A frozen dict of NumPy arrays with keys ``observations``,
            ``actions``, ``rewards``, ``next_observations``, ``dones``,
            ``indices`` (sampled slots) and ``weights`` (float32
            importance-sampling weights scaled so the largest is 1).

        Raises:
            ValueError: If the buffer is empty or every stored priority is
                zero, since no sampling distribution exists in either case.
        """
        N = len(self.buffer)
        if N == 0:
            raise ValueError("Cannot sample from an empty replay buffer.")

        # Only the first N slots hold valid priorities; slicing by N keeps
        # the probability vector the same length np.random.choice expects.
        prios = self.priorities[:N]

        # P = p^alpha / sum(p^alpha)
        probs = prios ** self.alpha
        total = probs.sum()
        if total <= 0:
            raise ValueError(
                "Cannot sample: all stored priorities are zero.")
        P = probs / total

        indices = np.random.choice(N, batch_size, p=P)
        samples = [self.buffer[idx] for idx in indices]

        beta = self.beta_by_frames(self.frame)
        self.frame += 1

        # Importance-sampling weights, scaled so the largest equals 1.
        weights = (N * P[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)

        observations, actions, rewards, next_observations, dones = zip(*samples)
        batch = {
            'observations': np.array(observations),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'next_observations': np.array(next_observations),
            'dones': np.array(dones),
            'indices': np.array(indices),
            'weights': np.array(weights),
        }
        return frozen_dict.freeze(batch)
Please:
I want to pass it as an input argument to the function:
However, the update function needs to be jit-compiled.
An error is reported when I run it directly.
Thank you very much for your help!