kennethZhangML opened this issue 1 year ago
```python
import collections

import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


class PrioritizedExperienceReplay:
    def __init__(self, buffer_size, alpha, beta, epsilon):
        self.buffer_size = buffer_size
        self.alpha = alpha      # priority exponent
        self.beta = beta        # importance-sampling exponent
        self.epsilon = epsilon  # small constant so no priority is ever zero
        self.buffer = collections.deque(maxlen=self.buffer_size)
        self.priorities = collections.deque(maxlen=self.buffer_size)
        self.priorities_sum = 0.0
        self.max_priority = 1.0

    def add_experience(self, experience, priority):
        # A full deque silently evicts its oldest entry on append, so remove
        # that entry's priority from the running sum before appending.
        if len(self.buffer) == self.buffer_size:
            self.priorities_sum -= self.priorities[0]
        self.buffer.append(experience)
        self.priorities.append(priority)
        self.priorities_sum += priority
        self.max_priority = max(self.max_priority, priority)

    def minibatch(self, batch_size):
        priorities = np.array(self.priorities)
        # Normalize by the sum of the alpha-scaled priorities (not the raw
        # running sum), so the probabilities passed to np.random.choice sum to 1.
        scaled = priorities ** self.alpha
        probabilities = scaled / scaled.sum()
        indices = np.random.choice(len(self.buffer), size=batch_size, p=probabilities)
        weights = (len(self.buffer) * probabilities[indices]) ** (-self.beta)
        weights /= np.max(weights)
        batch = [self.buffer[i] for i in indices]
        return batch, indices, weights

    def update_priorities(self, indices, errors):
        for i, error in zip(indices, errors):
            # Store the raw priority |error| + epsilon; alpha is applied once,
            # at sampling time. abs() also keeps negative TD errors from
            # producing NaNs under a fractional exponent.
            priority = abs(error) + self.epsilon
            self.priorities_sum -= self.priorities[i]
            self.priorities[i] = priority
            self.priorities_sum += priority
            self.max_priority = max(self.max_priority, priority)

    def compute_cost_matrix(self, C, p, q):
        # np.tile takes the repetition counts as a single tuple.
        p_matrix = np.tile(p, (len(q), 1)).T
        q_matrix = np.tile(q, (len(p), 1))
        M = C * p_matrix * q_matrix
        return M

    def compute_GW_distance(self, X, Y, p, q):
        C = cdist(X, Y)
        M = self.compute_cost_matrix(C, p, q)
        row_ind, col_ind = linear_sum_assignment(M)
        GW_distance = M[row_ind, col_ind].sum() / p.sum()
        return GW_distance

    def sample_idx(self, batch_size):
        probabilities = np.array(self.priorities) ** self.alpha
        probabilities /= np.sum(probabilities)
        idx = np.random.choice(len(self.buffer), size=batch_size,
                               replace=False, p=probabilities)
        return idx

    def minibatch_GW(self, batch_size, X, p):
        idx = self.sample_idx(batch_size)
        batch = [self.buffer[i] for i in idx]
        # cdist requires a 2-D array, so stack the sampled states.
        Y = np.asarray([experience.state for experience in batch])
        q = np.ones(len(Y)) / len(Y)
        GW_dist = self.compute_GW_distance(X, Y, p, q)
        return batch, idx, GW_dist

    def update_priorities_GW(self, indx, gw_dists):
        for i, gw_dist in zip(indx, gw_dists):
            # As above, store the raw priority; alpha is applied at sampling time.
            priority = gw_dist + self.epsilon
            self.priorities_sum -= self.priorities[i]
            self.priorities[i] = priority
            self.priorities_sum += priority
            self.max_priority = max(self.max_priority, priority)

    def get_max_priority(self):
        return self.max_priority

    def __len__(self):
        return len(self.buffer)
```
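For reference, here is a minimal usage sketch of the class above. The `Transition` namedtuple, the state dimensionality, and the random TD errors are all hypothetical stand-ins for whatever agent loop this buffer is plugged into; the original issue does not specify them.

```python
import collections
import numpy as np

# Hypothetical transition container; only .state is required by minibatch_GW.
Transition = collections.namedtuple("Transition", "state action reward next_state done")

per = PrioritizedExperienceReplay(buffer_size=1000, alpha=0.6, beta=0.4, epsilon=1e-5)

# New experiences are added at the current max priority so each one is
# sampled at least once before its priority is refined.
for _ in range(100):
    t = Transition(np.random.randn(4), 0, 0.0, np.random.randn(4), False)
    per.add_experience(t, per.get_max_priority())

# Standard TD-error-driven path: sample, train, then write back new priorities.
batch, indices, weights = per.minibatch(batch_size=32)
td_errors = np.random.rand(32)  # stand-in for |Q_target - Q| from a learner
per.update_priorities(indices, td_errors)

# GW path: compare a reference state set X (with uniform weights p) against
# the states of a sampled batch, then reuse the distance as a priority signal.
X = np.random.randn(32, 4)
p = np.ones(len(X)) / len(X)
batch, idx, gw_dist = per.minibatch_GW(batch_size=32, X=X, p=p)
per.update_priorities_GW(idx, [gw_dist] * len(idx))
```

Note that `minibatch_GW` returns a single scalar distance for the whole batch, so the sketch broadcasts it across all sampled indices when updating priorities.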