plurigrid / ontology

autopoietic ergodicity and embodied gradualism
https://vibes.lol

Prioritized Replay - Gromov Wasserstein Minibatch Sampling + Update Function #69

Open kennethZhangML opened 1 year ago

kennethZhangML commented 1 year ago
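For reference, the sampling and weighting in minibatch below follow the standard prioritized experience replay scheme (Schaul et al., 2015), with p_i the stored priority of transition i and N the number of stored transitions; the GW variant substitutes a Gromov-Wasserstein-style distance for the TD error when updating priorities:

$$P(i) = \frac{p_i^{\alpha}}{\sum_k p_k^{\alpha}}, \qquad w_i = \frac{\big(N \cdot P(i)\big)^{-\beta}}{\max_j \big(N \cdot P(j)\big)^{-\beta}}$$
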
import collections

import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

class PrioritizedExperienceReplay:
    def __init__(self, buffer_size, alpha, beta, epsilon):
        self.buffer_size = buffer_size
        self.alpha = alpha      # prioritization exponent
        self.beta = beta        # importance-sampling exponent
        self.epsilon = epsilon  # small constant keeping priorities strictly positive

        self.buffer = collections.deque(maxlen=self.buffer_size)
        self.priorities = collections.deque(maxlen=self.buffer_size)
        self.priorities_sum = 0.0
        self.max_priority = 1.0

    def add_experience(self, experience, priority):
        # The deques evict their oldest entry once full, so keep the running sum consistent.
        if len(self.priorities) == self.buffer_size:
            self.priorities_sum -= self.priorities[0]
        self.buffer.append(experience)
        self.priorities.append(priority)
        self.priorities_sum += priority

        if priority > self.max_priority:
            self.max_priority = priority

    def minibatch(self, batch_size):
        # Sample with probability P(i) = p_i^alpha / sum_k p_k^alpha.
        priorities = np.array(self.priorities)
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()
        indices = np.random.choice(len(self.buffer), size=batch_size, p=probabilities)
        # Importance-sampling weights w_i = (N * P(i))^(-beta), normalized by their maximum.
        weights = (len(self.buffer) * probabilities[indices]) ** (-self.beta)
        weights /= np.max(weights)
        batch = [self.buffer[i] for i in indices]
        return batch, indices, weights

    def update_priorities(self, indices, errors):
        # Recompute priorities from new errors for the sampled transitions,
        # keeping the running sum and max in sync.
        for i, error in zip(indices, errors):
            priority = (error + self.epsilon) ** self.alpha
            self.priorities_sum -= self.priorities[i]
            self.priorities[i] = priority
            self.priorities_sum += priority
            self.max_priority = max(self.max_priority, priority)

    def compute_cost_matrix(self, C, p, q):
        # Scale each pairwise cost by the product of the marginal masses: M_ij = C_ij * p_i * q_j.
        p_matrix = np.tile(p, (len(q), 1)).T
        q_matrix = np.tile(q, (len(p), 1))
        M = C * p_matrix * q_matrix
        return M

    def compute_GW_distance(self, X, Y, p, q):
        # Pairwise distances between the reference states X and the sampled states Y.
        C = cdist(np.asarray(X), np.asarray(Y))
        M = self.compute_cost_matrix(C, p, q)
        # Optimal one-to-one matching; the normalized matched cost serves as the
        # distance here (an assignment-based coupling cost between the two sets).
        row_ind, col_ind = linear_sum_assignment(M)
        GW_distance = M[row_ind, col_ind].sum() / np.sum(p)
        return GW_distance

    def sample_idx(self, batch_size):
        # Priority-weighted sampling of distinct buffer indices.
        probabilities = np.array(self.priorities) ** self.alpha
        probabilities /= np.sum(probabilities)
        idx = np.random.choice(len(self.buffer), size=batch_size, replace=False, p=probabilities)
        return idx

    def minibatch_GW(self, batch_size, X, p):
        # Draw a prioritized batch and score it by its distance to a reference state set X
        # (with marginal masses p); the batch states get uniform mass q.
        idx = self.sample_idx(batch_size)
        batch = [self.buffer[i] for i in idx]
        Y = [experience.state for experience in batch]
        q = np.ones(len(Y)) / len(Y)
        GW_dist = self.compute_GW_distance(X, Y, p, q)
        return batch, idx, GW_dist

    def update_priorities_GW(self, indx, gw_dists):
        # Same update rule as update_priorities, driven by GW distances instead of TD errors.
        for i, gw_dist in zip(indx, gw_dists):
            priority = (gw_dist + self.epsilon) ** self.alpha
            self.priorities_sum -= self.priorities[i]
            self.priorities[i] = priority
            self.priorities_sum += priority
            self.max_priority = max(self.max_priority, priority)

    def get_max_priority(self):
        return self.max_priority 

    def __len__(self):
        return len(self.buffer)
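
A minimal usage sketch, assuming a hypothetical Transition namedtuple (only its .state field is used by minibatch_GW), random placeholder data in place of real environment transitions and TD errors, and illustrative hyperparameter values:

import collections
import numpy as np

# Hypothetical transition record; any object with a .state attribute would do.
Transition = collections.namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])

replay = PrioritizedExperienceReplay(buffer_size=10_000, alpha=0.6, beta=0.4, epsilon=1e-2)

# Fill the buffer with dummy 4-dimensional states; new entries get the current max priority.
for _ in range(256):
    transition = Transition(np.random.randn(4), 0, 0.0, np.random.randn(4), False)
    replay.add_experience(transition, replay.get_max_priority())

# Standard prioritized sampling followed by a TD-error priority update.
batch, indices, weights = replay.minibatch(batch_size=32)
td_errors = np.abs(np.random.randn(len(indices)))  # placeholder for real TD errors
replay.update_priorities(indices, td_errors)

# GW-style sampling against a reference state set X with uniform mass p;
# minibatch_GW returns one batch-level distance, so it is broadcast per index here.
X = np.random.randn(32, 4)
p = np.ones(len(X)) / len(X)
batch, idx, gw_dist = replay.minibatch_GW(batch_size=32, X=X, p=p)
replay.update_priorities_GW(idx, np.full(len(idx), gw_dist))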