mit-han-lab / torchquantum

A PyTorch-based framework for Quantum Classical Simulation, Quantum Machine Learning, Quantum Neural Networks, Parameterized Quantum Circuits with support for easy deployments on real quantum computers.
https://torchquantum.org
MIT License
1.24k stars 185 forks source link

Torchquantum reinforcement learning agent behaves randomly even after epsilon-greedy phase #124

Open JustinS6626 opened 1 year ago

JustinS6626 commented 1 year ago

First of all, thanks again so much for your help with the parameter updating problem that I reported a few weeks ago. I have now encountered another issue. What seems to be happening is that even after the epsilon-greedy phase is over, the agent still behaves randomly. Since I deactivated randomness for the linear and convolutional layers, I am wondering what the source of the problem is. My code as it is right now is provided below:

import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
import calendar
import random
from minigrid.wrappers import *
import logging
from torchpack.callbacks import (InferenceRunner, MeanAbsoluteError,
                                 MaxSaver, MinSaver,
                                 Saver, SaverRestore, CategoricalAccuracy)
from torchpack.environ import set_run_dir
from torchpack.utils.config import configs
from torchpack.utils.logging import logger
from torchtest import assert_vars_change
from torch.nn.parameter import Parameter
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.measurement import *
import matplotlib.pyplot as plt
import pickle as pkl
from obs_wrappers import ImgObsFlatWrapper
import gymnasium as gym
from gymnasium.wrappers.record_video import RecordVideo
from collections import namedtuple, deque
#from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistic
import numpy as np
import os
from gymnasium.envs.registration import *
##from ._utils import _import_dotted_name
##from ._six import string_classes as _string_classes
##from torch._sources import get_source_lines_and_file
##from torch.types import Storage
##from torch.storage import _get_dtype_from_pickle_storage_type
##from typing_extensions import TypeAlias
##import copyreg

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def output_all(self):
        return self.memory

    def __len__(self):
        return len(self.memory)

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

def SO4(q_device, RY, RZ, CNOT, wires, static=None, parent_graph=None):
##    rz_pi = np.asarray([[np.exp(-1j * (np.pi / 4)), 0],
##                        [0, np.exp(1j * (np.pi / 4))]])
##    rz_neg_pi = np.asarray([[np.exp(1j * (np.pi / 4)), 0],
##                        [0, np.exp(-1j * (np.pi / 4))]])
##    ry_pi = np.asarray([[np.cos(np.pi / 4), -1 * np.sin(np.pi / 4)],
##                        [np.sin(np.pi / 4), np.cos(np.pi / 4)]])
##    ry_neg_pi = np.asarray([[np.cos(-1 * np.pi / 4), -1 * np.sin(-1 * np.pi / 4)],
##                            [np.sin(-1 * np.pi / 4), np.cos(-1 * np.pi / 4)]])
##    tqf.qubitunitary(device, wires=wires[0], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_pi)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    RZ[0](q_device, wires=wires[0])
    RZ[1](q_device, wires=wires[1])
    RY[0](q_device, wires=wires[0])
    RY[1](q_device, wires=wires[1])
    RZ[2](q_device, wires=wires[0])
    RZ[3](q_device, wires=wires[1])
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_neg_pi)
##    tqf.qubitunitary(device, wires=wires[0], params=rz_neg_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_neg_pi)

class TreeTensorAgent(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 8
            self.n_actions = 4
##                self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

            #self.bias = torch.tensor(np.random.rand(4), requires_grad=True)
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

        def forward(self, q_device, static_mode, graph):
            self.q_device = q_device
            #SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 1 Start
            tqf.rz(q_device, wires=0, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            self.rz_0_0(q_device, wires=0)
            self.rz_0_1(q_device, wires=1)
            self.ry_0_0(q_device, wires=0)
            self.ry_0_1(q_device, wires=1)
            self.rz_0_2(q_device, wires=0)
            self.rz_0_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=0, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 1 End

            #SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 2 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            self.rz_1_0(q_device, wires=2)
            self.rz_1_1(q_device, wires=3)
            self.ry_1_0(q_device, wires=2)
            self.ry_1_1(q_device, wires=3)
            self.rz_1_2(q_device, wires=2)
            self.rz_1_3(q_device, wires=3)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 2 End

            #SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 3 Start
            tqf.rz(q_device, wires=4, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            self.rz_2_0(q_device, wires=4)
            self.rz_2_1(q_device, wires=5)
            self.ry_2_0(q_device, wires=4)
            self.ry_2_1(q_device, wires=5)
            self.rz_2_2(q_device, wires=4)
            self.rz_2_3(q_device, wires=5)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=4, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 3 End

            #SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 4 Start
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            self.rz_3_0(q_device, wires=6)
            self.rz_3_1(q_device, wires=7)
            self.ry_3_0(q_device, wires=6)
            self.ry_3_1(q_device, wires=7)
            self.rz_3_2(q_device, wires=6)
            self.rz_3_3(q_device, wires=7)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 4 End

            #SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 1 Start
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            self.rz_4_0(q_device, wires=1)
            self.rz_4_1(q_device, wires=2)
            self.ry_4_0(q_device, wires=1)
            self.ry_4_1(q_device, wires=2)
            self.rz_4_2(q_device, wires=1)
            self.rz_4_3(q_device, wires=2)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 1 End

            #SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 2 Start
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            self.rz_5_0(q_device, wires=5)
            self.rz_5_1(q_device, wires=6)
            self.ry_5_0(q_device, wires=5)
            self.ry_5_1(q_device, wires=6)
            self.rz_5_2(q_device, wires=5)
            self.rz_5_3(q_device, wires=6)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 2 End
            #SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 3 Gate 1 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode)
            self.rz_6_0(q_device, wires=0)
            self.rz_6_1(q_device, wires=1)
            self.ry_6_0(q_device, wires=0)
            self.ry_6_1(q_device, wires=1)
            self.rz_6_2(q_device, wires=0)
            self.rz_6_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode, parent_graph=graph)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 3 Gate 1 End
    def __init__(self, input_size):
        super().__init__()
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        #self.bias = Parameter(torch.zeros(self.n_actions))
        self.smx = nn.Softmax()
        self.bitstrings = gen_bitstrings(self.n_wires)
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        #self.mps = MPS(input_dim = 147, output_dim = 8, bond_dim = 2, feature_dim = 2, use_GPU = False, parallel = True, init_std=1e-2)
##        self.feature_map = nn.Sequential(nn.Linear(self.input_size, 64), nn.ReLU(),
##                                         nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2), nn.ReLU(),
##                                         nn.Conv1d(64, 1, kernel_size=2, stride=3), nn.Tanh())
        self.layer_1 = nn.Linear(1,256)
        #147 x 128
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 256, kernel_size=2, padding=1, dilation=2)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(256, 1, kernel_size=4, stride=8, padding=1, dilation=3)
        self.layer_6 = nn.CELU()
##        for param in self.mps.parameters():
##            print("Gradient required: " + str(param.requires_grad))

##        self.measure = tq.MeasureMultiPauliSum(
##            obs_list=[{"wires" : [3, 4],
##                       "observables" : ['z', 'z'],
##                       "coefficient" : [1, 1]}])

##        self.encoder1 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
##             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
##             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
##             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
##             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
##             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
##             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
##             {"input_idx" : [7], "func" : "ry", "wires" : [7]}])
##
##        self.encoder2 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "rz", "wires" : [0]},
##             {"input_idx" : [1], "func" : "rz", "wires" : [1]},
##             {"input_idx" : [2], "func" : "rz", "wires" : [2]},
##             {"input_idx" : [3], "func" : "rz", "wires" : [3]},
##             {"input_idx" : [4], "func" : "rz", "wires" : [4]},
##             {"input_idx" : [5], "func" : "rz", "wires" : [5]},
##             {"input_idx" : [6], "func" : "rz", "wires" : [6]},
##             {"input_idx" : [7], "func" : "rz", "wires" : [7]}])

        self.encoder=tq.GeneralEncoder(
            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "rz", "wires" : [0]},
             {"input_idx" : [9], "func" : "rz", "wires" : [1]},
             {"input_idx" : [10], "func" : "rz", "wires" : [2]},
             {"input_idx" : [11], "func" : "rz", "wires" : [3]},
             {"input_idx" : [12], "func" : "rz", "wires" : [4]},
             {"input_idx" : [13], "func" : "rz", "wires" : [5]},
             {"input_idx" : [14], "func" : "rz", "wires" : [6]},
             {"input_idx" : [15], "func" : "rz", "wires" : [7]},
             {"input_idx" : [16], "func" : "rx", "wires" : [0]},
             {"input_idx" : [17], "func" : "rx", "wires" : [1]},
             {"input_idx" : [18], "func" : "rx", "wires" : [2]},
             {"input_idx" : [19], "func" : "rx", "wires" : [3]},
             {"input_idx" : [20], "func" : "rx", "wires" : [4]},
             {"input_idx" : [21], "func" : "rx", "wires" : [5]},
             {"input_idx" : [22], "func" : "rx", "wires" : [6]},
             {"input_idx" : [23], "func" : "rx", "wires" : [7]},
             {"input_idx" : [24], "func" : "rz", "wires" : [0]},
             {"input_idx" : [25], "func" : "rz", "wires" : [1]},
             {"input_idx" : [26], "func" : "rz", "wires" : [2]},
             {"input_idx" : [27], "func" : "rz", "wires" : [3]},
             {"input_idx" : [28], "func" : "rz", "wires" : [4]},
             {"input_idx" : [29], "func" : "rz", "wires" : [5]},
             {"input_idx" : [30], "func" : "rz", "wires" : [6]},
             {"input_idx" : [31], "func" : "rz", "wires" : [7]}])

##    def get_angles_atan(self, in_x):
##        angles = torch.stack([torch.stack([torch.atan(item), torch.atan(item**2)]) for item in in_x])
##        return angles

    def forward(self, input_data, check=False):
        #measure_counts = np.zeros(self.n_actions)
        prob_dict = {}
        #x = self.feature_map(input_data)
        x_1 = self.layer_1(input_data)
        x_2 = self.layer_2(x_1)
        #print("Stage one size: " + str(x_2.shape))
        x_3 = self.layer_3(x_2)
        x_4 = self.layer_4(x_3)
        #print("Stage two size: " + str(x_3.shape))
        x_5 = self.layer_5(x_4)
        x_6 = self.layer_6(x_5)
        #print("Stage three size " + str(x_6.shape))
        #print(type(x))
        #print(x.shape)
        #x_angles = self.get_angles_atan(x)
        #x_angles = torch.stack([torch.atan(x), torch.atan(x ** 2)])
##        print("Angle array shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
        #torch.reshape(x_angles, (1, 16))
##        new_x_angles = x_angles.view(1, 16)
##        x_angles = new_x_angles
##        print("Input shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
##        if check:
##            print(x_angles)
        #print(x_angles[0][0])
        x_angles = torch.atan(x_6)
        for i in range(self.n_wires):

            tqf.hadamard(self.q_device, wires=i, static=self.static_mode, parent_graph=self.graph)

##        self.encoder1(self.q_device, x_angles[0][0])
##        self.encoder2(self.q_device, x_angles[1][0])
        self.encoder(self.q_device, x_angles)
        #print("Parent graph: " + str(self.graph))
        self.q_layer.forward(self.q_device, self.static_mode, self.graph)

##        SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1])
##        SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3])
##        SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5])
##        SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7])
##        SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2])
##        SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6])
##        SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5])
        #print("ops done")
##        device_states = self.q_device.get_states_1d()
##        #print(device_states)
##        circuit_state = tq.QuantumState(n_wires=self.n_wires)
##        circuit_state.set_states(device_states)
        #state_vec = self.q_device.get_states_1d().abs().detach().cpu().numpy()
##        print("State vector: ")
##        print(state_vec)
##        print(state_vec.shape)
        #measures = tq.measure(self.q_device, n_shots=4096)
##        for i in range(len(self.bitstrings)):
##            prob_dict[self.bitstrings[i]] = np.abs(state_vec[0][i]) ** 2
##        qbit_states = list(prob_dict.keys())
        #print(qbit_states)
##        print(type(measure_results))
        #print(measure_results)
        #qbit_states = [result.keys() for result in measure_results]
        #print(qbit_states)
##        for bitkey in qbit_states:
##            if bitkey[3] == '0' and bitkey[4] == '0':
##                measure_counts[0] += prob_dict[bitkey]
##            elif bitkey[3] == '0' and bitkey[4] == '1':
##                measure_counts[1] += prob_dict[bitkey]
##            elif bitkey[3] == '1' and bitkey[4] == '0':
##                measure_counts[2] += prob_dict[bitkey]
##            else:
##                measure_counts[3] += prob_dict[bitkey]
##        measure_norm = np.linalg.norm(measure_counts)
##        measure_counts = measure_counts / measure_norm
##        if check:
##            print("Measure outcomes: ")
##            print(measure_counts)
        #measure_weights = torch.tensor(measure_counts, requires_grad=True)
        obs_1 = expval_joint_analytical(self.q_device, "ZZZXXZZZ")
        obs_2 = expval_joint_analytical(self.q_device, "ZZZYYZZZ")
        obs_3 = expval_joint_analytical(self.q_device, "ZZZYXZZZ")
        obs_4 = expval_joint_analytical(self.q_device, "ZZZXYZZZ")
        expectations = torch.stack([obs_1, obs_2, obs_3, obs_4], dim=1)
        #measure_weights = self.smx(measure_results)
        if check:
            print("Measure weights: ")
            print(expectations)
        #print("Gradient preserved: " + str(measure_weights.requires_grad))
        measure_weights = expectations.view(4)
        #print("Output shape: " + str(measure_weights.shape))
##        if check:
##            print(measure_weights)
        #print("Measure results")
        #print(measure_counts)
        return measure_weights

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
##        print(type(l))
##        print(type(p))
        loss = loss + ((l - p) ** 2)
    loss = loss / len(labels)
    #print(type(loss))
    return loss

def epsilon_greedy(TreeTensor, epsilon, s, n_actions, timestep, rgen, check=False, train=False):
##    seed = int(time.time())
##    rng = np.random.default_rng(seed)
    if train or rgen.random() < ((epsilon / n_actions) + (1 - epsilon)):
        with torch.no_grad():
            measurements = TreeTensor(s, check=check)
            action = torch.argmax(measurements)
            if check:
                print("Argmax result: " + str(action))
            return action

        #print("Circuit")
    else:
##        seedval = int(time.time())
##        np.random.seed(seedVal)
        action = rgen.integers(0, high=n_actions)
##        if check:
##            print(choices)
        #action = np.bincount(choices).argmax()
##        if check:
##            print(action)

        print("Epsilon")
        action = torch.tensor(action)
        return action

def cost(model, features, labels, dev):
    #print(features)
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    return loss_total

def ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, top_dev, opt, sched, render=True):

    act_range = [0, 1, 2, 6]
    logging.basicConfig(filename="ExperimentDebug1.txt", level=logging.DEBUG)
    logging.captureWarnings(True)
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit = model.to(main_device)

    param_file = "TTN_params.bin"
    scores = []
    target_update = 20
    batch_size = 100
    optimize_steps = 5
    target_update_counter = 0
    iter_index = []
    iter_reward = []
    iter_total_steps = []
    cost_list = []
    timestep_reward = []
    random.seed(int(time.time()))
    seed = int(time.time())
    rng = np.random.default_rng(seed)
    memory = ReplayMemory(500)
##    seedVal = int(time.time())
##    np.random.seed(seedVal)
    #q_device = tq.QuantumDevice(n_wires=8)
    #print(type(q_device))
    #optimizer = optim.Adam(model.parameters(), lr=alpha, weight_decay=1e-4)
    #optimizer_mps = optim.Adam(model.mps.parameters(), lr=alpha, weight_decay=1e-4)
##    optimizer = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
##    optimizer_mps = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
    #scheduler = CosineAnnealingLR(optimizer, T_max=episodes)
    #scheduler_mps = CosineAnnealingLR(optimizer_mps, T_max=episodes)
    env = gym.make(env_name, max_episode_steps=max_steps, disable_env_checker=True, render_mode="human")
    #print(type(env))
    env = ImgObsFlatWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
##    record_dict = env_record.__dict__
##    record_keys = record_dict.keys()
##    print(record_keys)
    start_state = None
    start_time = time.asctime()
    for episode in range(episodes):
        env_record.reset()
        print("Episode: " + str(episode))
        t = 0
        total_reward = 0
        #print("Reset reward")
        done = False
        #print("Not done")
        #rgen = np.random.RandomState(seedVal)
        seedVal = int(time.time())
        np.random.seed(seedVal)
##        print(type(observation))
##        print(observation)
        #print("Reset complete")
        if episode == 0:
            start_state = env_record.env.grid
            #print("Start state set")
        else:
            env_record.env.grid = start_state
            #print("Start state retrieved")
        if render:
            env_record.render()
        #print("Number of obstacles: " + str(len(env_record.env.obstacles)))
        observation = env_record.env.gen_obs()
        #print("Got observation")
        #print(type(observation))
#        print(observation)
        observation = torch.tensor(observation['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
        print("Observation shape: " + str(observation.shape))
        #print("Observation formatted")
        #print(observation)
        observation.requires_grad = True
        #observation = observation.to(top_dev)
        act = epsilon_greedy(model, epsilon, observation, n_actions, t, rng, check=True)
        #print("Action index: " + str(act))
##        print("Got action")
        action = act_range[act]
        #print("Action selection: " + str(action))
        #scores.append(total_reward)
        while t < max_steps:
            print("Episode: " + str(episode) + " , " + "Timestep: " + str(t))
            if render:
                env_record.render()
            #print("Time Step: " + str(t))
            t += 1
            target_update_counter += 1
            seedVal = int(time.time())
            np.random.seed(seedVal)
            next_obs, reward, done, _, info = env_record.step(action)
            print("Step reward: " + str(reward))
            #print("Step reward: " + str(reward))
            #print(type(next_obs))
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(147, 1).to(top_dev)
            next_obs.requires_grad = True
            #new_obs = next_obs.to(main_device)
            total_reward += reward
            act_ = epsilon_greedy(model, epsilon, next_obs, n_actions, t, rng, check=True)
            #print("Action index: " + str(act))
            action_ = act_range[act_]
            #print("Action selection: " + str(action))
            memory.push(observation, act, reward, next_obs, done)
            if len(memory) > batch_size and done:
                batch_sampled = memory.sample(batch_size)
##                batch = Transition(*zip(*transitions))
##                non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
##                non_final_next_states = torch.cat([s for s in batch.next_state
##                                                if s is not None])
##                state_batch = torch.cat(batch.state)
##                action_batch = torch.cat(batch.action)
##                reward_batch = torch.cat(batch.reward)
                Qtarget = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in batch_sampled]
                loss = cost(model, batch_sampled, Qtarget, top_dev)
                #print("Loss Gradient Function: " + str(loss.grad_fn))
##                grads = torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)
##                print("Grad type: " + str(grads))
                #print("Loss: " + str(loss))

                #optimizer_mps.zero_grad()
                opt.zero_grad()
                    #print(loss)
                    #print(type(loss))
##                loss.backward()
##                for param in model.parameters():
##                    print("Parameter gradient: " + str(param.grad))
                opt.step()
                #optimizer_mps.step()
                #print(model.parameters())
                #print("Optimization step")
            #scheduler.step()
            #scheduler_mps.step()
##            current_replay_memory = memory.output_all()
##            current_target_for_replay_memory = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in current_replay_memory]
            if target_update_counter >= target_update:
                target_update_counter = 0
            observation, action = next_obs, action_
            if done or t == max_steps:
                epsilon = epsilon / ((episode / 750) + 1)
                alpha = 0.95 * alpha
                timestep_reward.append(total_reward)
                print("Reward data length: " + str(len(timestep_reward)))
                iter_index.append(episode)
                iter_total_steps.append(t)
                break
    stop_time = time.asctime()
    print("Start time: ")
    print(start_time)
    print("Stop time: ")
    print(stop_time)
    torch.save(model.state_dict(), param_file)
    return timestep_reward, iter_index, iter_reward, iter_total_steps

def test_agent(model, env_folder, epsilon, env_name, config_name, n_tests, max_steps, delay=1):
    act_range = [0, 1, 2, 6]
    n_successes = 0
    test_rewards = []
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit = model.to(main_device)
    env = gym.make(env_name, max_episode_steps=max_steps, render_mode="human", height=64, width=64)
    env = SymbolicObsWrapper(env)
    env = ImgObsFlatWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
    done = False
    for test in range(n_tests):
        reward_total = 0
        epsilon = 0
        env.reset()
        filename = env_folder + "/" + env_name + "_" + str(test)
        statefile = open(filename, "wb")
        state_data = pkl.load(statefile)
        new_grid = env_record.env.grid.decode(state_data)
        env_record.env.grid = new_grid
        while True:
            time.sleep(delay)
            s = torch.tensor(observation).type('torch.FloatTensor').view(1, -1)
            act = epsilon_greedy(model, epsilon, observation)
            a = act_range[act]
            next_obs, reward, done, info = env_record.step(a)
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(1, -1)
            reward_total += reward
            if done:
                if reward > 0:
                    n_successes += 1
                    print("Goal Reached")
                else:
                    print("Task Failed")
                test_rewards.append(reward_total)
                time.sleep(3)
                break
    return test_rewards, n_sucesses

def main():
##    register(
##        id="Minigrid-RandomLava-6Spots-v0",
##        entry_point="RandomLavaMinigrid:RandomLavaEnv",
##        kwargs={"size": 8, "n_obstacles": 6})
    #env_name = "Minigrid-RandomLava-6Spots-v0"
    env_name = "MiniGrid-Empty-8x8-v0"
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    #device = torch.device("cpu")
    alpha = 0.4
    gamma = 0.5
    epsilon = 1
    episodes = 1000
    max_steps = 100
    n_actions = 4
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    model = TreeTensorAgent(147).to(device)
    print("Model graph: " + str(model.graph))

    tn_opt = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(tn_opt, T_max=episodes)
    timestep_reward, iter_index, iter_reward, iter_total_steps = ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, device, tn_opt, scheduler)
    x_vals = np.arange(episodes)
    y_vals = np.asarray(timestep_reward)
    fig, ax = plt.subplots()
    ax.plot(x_vals, y_vals)
    ax.grid()
    ax.set(xlabel="Episode", ylabel="Total Score", title="Deep Quantum TTN Learning Training Process: 6-Site Random Lava Minigrid")
    fig.savefig("TTNTrain.png")
    plt.close(fig)

if __name__ == "__main__":
    main()

I would greatly appreciate your help.

JustinS6626 commented 1 year ago

I just notice that the code that I posted has a bad import. This one should replicate the issue:

import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
import calendar
import random
from minigrid.wrappers import *
import logging
from torchpack.callbacks import (InferenceRunner, MeanAbsoluteError,
                                 MaxSaver, MinSaver,
                                 Saver, SaverRestore, CategoricalAccuracy)
from torchpack.environ import set_run_dir
from torchpack.utils.config import configs
from torchpack.utils.logging import logger
from torchtest import assert_vars_change
from torch.nn.parameter import Parameter
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.measurement import *
import matplotlib.pyplot as plt
import pickle as pkl
import gymnasium as gym
from gymnasium.wrappers.record_video import RecordVideo
from collections import namedtuple, deque
#from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistic
import numpy as np
import os
from gymnasium.envs.registration import *
##from ._utils import _import_dotted_name
##from ._six import string_classes as _string_classes
##from torch._sources import get_source_lines_and_file
##from torch.types import Storage
##from torch.storage import _get_dtype_from_pickle_storage_type
##from typing_extensions import TypeAlias
##import copyreg

#os.environ["SDL_VIDEODRIVER"] = "dummy"

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def output_all(self):
        return self.memory

    def __len__(self):
        return len(self.memory)

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

def SO4(q_device, RY, RZ, CNOT, wires, static=None, parent_graph=None):
##    rz_pi = np.asarray([[np.exp(-1j * (np.pi / 4)), 0],
##                        [0, np.exp(1j * (np.pi / 4))]])
##    rz_neg_pi = np.asarray([[np.exp(1j * (np.pi / 4)), 0],
##                        [0, np.exp(-1j * (np.pi / 4))]])
##    ry_pi = np.asarray([[np.cos(np.pi / 4), -1 * np.sin(np.pi / 4)],
##                        [np.sin(np.pi / 4), np.cos(np.pi / 4)]])
##    ry_neg_pi = np.asarray([[np.cos(-1 * np.pi / 4), -1 * np.sin(-1 * np.pi / 4)],
##                            [np.sin(-1 * np.pi / 4), np.cos(-1 * np.pi / 4)]])
##    tqf.qubitunitary(device, wires=wires[0], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_pi)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    RZ[0](q_device, wires=wires[0])
    RZ[1](q_device, wires=wires[1])
    RY[0](q_device, wires=wires[0])
    RY[1](q_device, wires=wires[1])
    RZ[2](q_device, wires=wires[0])
    RZ[3](q_device, wires=wires[1])
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_neg_pi)
##    tqf.qubitunitary(device, wires=wires[0], params=rz_neg_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_neg_pi)

class TreeTensorAgent(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 8
            self.n_actions = 4
##                self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

            #self.bias = torch.tensor(np.random.rand(4), requires_grad=True)
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

        def forward(self, q_device):
            self.q_device = q_device
            #SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 1 Start
            tqf.rz(q_device, wires=0, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=1, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[1, 0])
            self.rz_0_0(q_device, wires=0)
            self.rz_0_1(q_device, wires=1)
            self.ry_0_0(q_device, wires=0)
            self.ry_0_1(q_device, wires=1)
            self.rz_0_2(q_device, wires=0)
            self.rz_0_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[1, 0])
            tqf.ry(q_device, wires=1, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=0, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 1 End

            #SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 2 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=3, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=3, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[3, 2])
            self.rz_1_0(q_device, wires=2)
            self.rz_1_1(q_device, wires=3)
            self.ry_1_0(q_device, wires=2)
            self.ry_1_1(q_device, wires=3)
            self.rz_1_2(q_device, wires=2)
            self.rz_1_3(q_device, wires=3)
            tqf.cnot(q_device, wires=[3, 2])
            tqf.ry(q_device, wires=3, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=3, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 2 End

            #SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 3 Start
            tqf.rz(q_device, wires=4, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[5, 4])
            self.rz_2_0(q_device, wires=4)
            self.rz_2_1(q_device, wires=5)
            self.ry_2_0(q_device, wires=4)
            self.ry_2_1(q_device, wires=5)
            self.rz_2_2(q_device, wires=4)
            self.rz_2_3(q_device, wires=5)
            tqf.cnot(q_device, wires=[5, 4])
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=4, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 3 End

            #SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 4 Start
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=7, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=7, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[7, 6])
            self.rz_3_0(q_device, wires=6)
            self.rz_3_1(q_device, wires=7)
            self.ry_3_0(q_device, wires=6)
            self.ry_3_1(q_device, wires=7)
            self.rz_3_2(q_device, wires=6)
            self.rz_3_3(q_device, wires=7)
            tqf.cnot(q_device, wires=[7, 6])
            tqf.ry(q_device, wires=7, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=7, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 4 End

            #SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 1 Start
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[2, 1])
            self.rz_4_0(q_device, wires=1)
            self.rz_4_1(q_device, wires=2)
            self.ry_4_0(q_device, wires=1)
            self.ry_4_1(q_device, wires=2)
            self.rz_4_2(q_device, wires=1)
            self.rz_4_3(q_device, wires=2)
            tqf.cnot(q_device, wires=[2, 1])
            tqf.ry(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            #Layer 2 Gate 1 End

            #SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 2 Start
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=6, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[6, 5])
            self.rz_5_0(q_device, wires=5)
            self.rz_5_1(q_device, wires=6)
            self.ry_5_0(q_device, wires=5)
            self.ry_5_1(q_device, wires=6)
            self.rz_5_2(q_device, wires=5)
            self.rz_5_3(q_device, wires=6)
            tqf.cnot(q_device, wires=[6, 5])
            tqf.ry(q_device, wires=6, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]))
            #Layer 2 Gate 2 End
            #SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 3 Gate 1 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[5, 2])
            self.rz_6_0(q_device, wires=0)
            self.rz_6_1(q_device, wires=1)
            self.ry_6_0(q_device, wires=0)
            self.ry_6_1(q_device, wires=1)
            self.rz_6_2(q_device, wires=0)
            self.rz_6_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[5, 2])
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            #Layer 3 Gate 1 End
    def __init__(self, input_size):
        super().__init__()
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        #self.bias = Parameter(torch.zeros(self.n_actions))
        self.smx = nn.Softmax()
        self.bitstrings = gen_bitstrings(self.n_wires)
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        #self.mps = MPS(input_dim = 147, output_dim = 8, bond_dim = 2, feature_dim = 2, use_GPU = False, parallel = True, init_std=1e-2)
##        self.feature_map = nn.Sequential(nn.Linear(self.input_size, 64), nn.ReLU(),
##                                         nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2), nn.ReLU(),
##                                         nn.Conv1d(64, 1, kernel_size=2, stride=3), nn.Tanh())
        self.layer_1 = nn.Linear(1, 256, bias=False)
        #147 x 128
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 256, kernel_size=2, padding=1, dilation=2, bias=False)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(256, 1, kernel_size=4, stride=8, padding=1, dilation=3, bias=False)
        self.layer_6 = nn.CELU()
##        for param in self.mps.parameters():
##            print("Gradient required: " + str(param.requires_grad))

##        self.measure = tq.MeasureMultiPauliSum(
##            obs_list=[{"wires" : [3, 4],
##                       "observables" : ['z', 'z'],
##                       "coefficient" : [1, 1]}])

##        self.encoder1 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
##             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
##             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
##             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
##             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
##             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
##             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
##             {"input_idx" : [7], "func" : "ry", "wires" : [7]}])
##
##        self.encoder2 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "rz", "wires" : [0]},
##             {"input_idx" : [1], "func" : "rz", "wires" : [1]},
##             {"input_idx" : [2], "func" : "rz", "wires" : [2]},
##             {"input_idx" : [3], "func" : "rz", "wires" : [3]},
##             {"input_idx" : [4], "func" : "rz", "wires" : [4]},
##             {"input_idx" : [5], "func" : "rz", "wires" : [5]},
##             {"input_idx" : [6], "func" : "rz", "wires" : [6]},
##             {"input_idx" : [7], "func" : "rz", "wires" : [7]}])

        self.encoder=tq.GeneralEncoder(
            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "rz", "wires" : [0]},
             {"input_idx" : [9], "func" : "rz", "wires" : [1]},
             {"input_idx" : [10], "func" : "rz", "wires" : [2]},
             {"input_idx" : [11], "func" : "rz", "wires" : [3]},
             {"input_idx" : [12], "func" : "rz", "wires" : [4]},
             {"input_idx" : [13], "func" : "rz", "wires" : [5]},
             {"input_idx" : [14], "func" : "rz", "wires" : [6]},
             {"input_idx" : [15], "func" : "rz", "wires" : [7]},
             {"input_idx" : [16], "func" : "rx", "wires" : [0]},
             {"input_idx" : [17], "func" : "rx", "wires" : [1]},
             {"input_idx" : [18], "func" : "rx", "wires" : [2]},
             {"input_idx" : [19], "func" : "rx", "wires" : [3]},
             {"input_idx" : [20], "func" : "rx", "wires" : [4]},
             {"input_idx" : [21], "func" : "rx", "wires" : [5]},
             {"input_idx" : [22], "func" : "rx", "wires" : [6]},
             {"input_idx" : [23], "func" : "rx", "wires" : [7]},
             {"input_idx" : [24], "func" : "rz", "wires" : [0]},
             {"input_idx" : [25], "func" : "rz", "wires" : [1]},
             {"input_idx" : [26], "func" : "rz", "wires" : [2]},
             {"input_idx" : [27], "func" : "rz", "wires" : [3]},
             {"input_idx" : [28], "func" : "rz", "wires" : [4]},
             {"input_idx" : [29], "func" : "rz", "wires" : [5]},
             {"input_idx" : [30], "func" : "rz", "wires" : [6]},
             {"input_idx" : [31], "func" : "rz", "wires" : [7]}])

##    def get_angles_atan(self, in_x):
##        angles = torch.stack([torch.stack([torch.atan(item), torch.atan(item**2)]) for item in in_x])
##        return angles

    def forward(self, input_data, check=False):
        #measure_counts = np.zeros(self.n_actions)
        prob_dict = {}
        #x = self.feature_map(input_data)
        x_1 = self.layer_1(input_data)
        x_2 = self.layer_2(x_1)
        #print("Stage one size: " + str(x_2.shape))
        x_3 = self.layer_3(x_2)
        x_4 = self.layer_4(x_3)
        #print("Stage two size: " + str(x_3.shape))
        x_5 = self.layer_5(x_4)
        x_6 = self.layer_6(x_5)
        #print("Stage three size " + str(x_6.shape))
        #print(type(x))
        #print(x.shape)
        #x_angles = self.get_angles_atan(x)
        #x_angles = torch.stack([torch.atan(x), torch.atan(x ** 2)])
##        print("Angle array shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
        #torch.reshape(x_angles, (1, 16))
##        new_x_angles = x_angles.view(1, 16)
##        x_angles = new_x_angles
##        print("Input shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
##        if check:
##            print(x_angles)
        #print(x_angles[0][0])
        x_angles = torch.atan(x_6)
        for i in range(self.n_wires):

            tqf.hadamard(self.q_device, wires=i)

##        self.encoder1(self.q_device, x_angles[0][0])
##        self.encoder2(self.q_device, x_angles[1][0])
        self.encoder(self.q_device, x_angles)
        #print("Parent graph: " + str(self.graph))
        self.q_layer.forward(self.q_device)

##        SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1])
##        SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3])
##        SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5])
##        SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7])
##        SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2])
##        SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6])
##        SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5])
        #print("ops done")
##        device_states = self.q_device.get_states_1d()
##        #print(device_states)
##        circuit_state = tq.QuantumState(n_wires=self.n_wires)
##        circuit_state.set_states(device_states)
        #state_vec = self.q_device.get_states_1d().abs().detach().cpu().numpy()
##        print("State vector: ")
##        print(state_vec)
##        print(state_vec.shape)
        #measures = tq.measure(self.q_device, n_shots=4096)
##        for i in range(len(self.bitstrings)):
##            prob_dict[self.bitstrings[i]] = np.abs(state_vec[0][i]) ** 2
##        qbit_states = list(prob_dict.keys())
        #print(qbit_states)
##        print(type(measure_results))
        #print(measure_results)
        #qbit_states = [result.keys() for result in measure_results]
        #print(qbit_states)
##        for bitkey in qbit_states:
##            if bitkey[3] == '0' and bitkey[4] == '0':
##                measure_counts[0] += prob_dict[bitkey]
##            elif bitkey[3] == '0' and bitkey[4] == '1':
##                measure_counts[1] += prob_dict[bitkey]
##            elif bitkey[3] == '1' and bitkey[4] == '0':
##                measure_counts[2] += prob_dict[bitkey]
##            else:
##                measure_counts[3] += prob_dict[bitkey]
##        measure_norm = np.linalg.norm(measure_counts)
##        measure_counts = measure_counts / measure_norm
##        if check:
##            print("Measure outcomes: ")
##            print(measure_counts)
        #measure_weights = torch.tensor(measure_counts, requires_grad=True)
        obs_1 = expval_joint_analytical(self.q_device, "ZZZXXZZZ")
        obs_2 = expval_joint_analytical(self.q_device, "ZZZYYZZZ")
        obs_3 = expval_joint_analytical(self.q_device, "ZZZYXZZZ")
        obs_4 = expval_joint_analytical(self.q_device, "ZZZXYZZZ")
        expectations = torch.stack([obs_1, obs_2, obs_3, obs_4], dim=1)
        #measure_weights = self.smx(measure_results)
        if check:
            print("Measure weights: ")
            print(expectations)
        #print("Gradient preserved: " + str(measure_weights.requires_grad))
        measure_weights = expectations.view(4)
        #print("Output shape: " + str(measure_weights.shape))
##        if check:
##            print(measure_weights)
        #print("Measure results")
        #print(measure_counts)
        return measure_weights

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
##        print(type(l))
##        print(type(p))
        loss = loss + ((l - p) ** 2)
    loss = loss / len(labels)
    #print(type(loss))
    return loss

def epsilon_greedy(TreeTensor, epsilon, s, n_actions, timestep, rgen, check=False, train=False):
##    seed = int(time.time())
##    rng = np.random.default_rng(seed)
    if train or rgen.random() < ((epsilon / n_actions) + (1 - epsilon)):
        with torch.no_grad():
            measurements = TreeTensor(s, check=check)
            action = torch.argmax(measurements)
            if check:
                print("Argmax result: " + str(action))
            return action

        #print("Circuit")
    else:
##        seedval = int(time.time())
##        np.random.seed(seedVal)
        action = rgen.integers(0, high=n_actions)
##        if check:
##            print(choices)
        #action = np.bincount(choices).argmax()
##        if check:
##            print(action)

        print("Epsilon")
        action = torch.tensor(action)
        return action

def cost(model, features, labels, dev):
    #print(features)
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    return loss_total

def ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, top_dev, opt, sched, render=True):

    act_range = [0, 1, 2, 6]
    logging.basicConfig(filename="ExperimentDebug1.txt", level=logging.DEBUG)
    logging.captureWarnings(True)
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit = model.to(main_device)

    param_file = "TTN_params.bin"
    scores = []
    target_update = 20
    batch_size = 100
    optimize_steps = 5
    target_update_counter = 0
    iter_index = []
    iter_reward = []
    iter_total_steps = []
    cost_list = []
    timestep_reward = []
    random.seed(int(time.time()))
    seed = int(time.time())
    rng = np.random.default_rng(seed)
    memory = ReplayMemory(500)
##    seedVal = int(time.time())
##    np.random.seed(seedVal)
    #q_device = tq.QuantumDevice(n_wires=8)
    #print(type(q_device))
    #optimizer = optim.Adam(model.parameters(), lr=alpha, weight_decay=1e-4)
    #optimizer_mps = optim.Adam(model.mps.parameters(), lr=alpha, weight_decay=1e-4)
##    optimizer = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
##    optimizer_mps = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
    #scheduler = CosineAnnealingLR(optimizer, T_max=episodes)
    #scheduler_mps = CosineAnnealingLR(optimizer_mps, T_max=episodes)
    env = gym.make(env_name, max_episode_steps=max_steps, disable_env_checker=True, render_mode="human")
    #print(type(env))
    env = ImgObsWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
##    record_dict = env_record.__dict__
##    record_keys = record_dict.keys()
##    print(record_keys)
    start_state = None
    start_time = time.asctime()
    for episode in range(episodes):
        env_record.reset()
        print("Episode: " + str(episode))
        t = 0
        total_reward = 0
        #print("Reset reward")
        done = False
        #print("Not done")
        #rgen = np.random.RandomState(seedVal)
        seedVal = int(time.time())
        np.random.seed(seedVal)
##        print(type(observation))
##        print(observation)
        #print("Reset complete")
        if episode == 0:
            start_state = env_record.env.grid
            #print("Start state set")
        else:
            env_record.env.grid = start_state
            #print("Start state retrieved")
        if render:
            env_record.render()
        #print("Number of obstacles: " + str(len(env_record.env.obstacles)))
        observation = env_record.env.gen_obs()
        #print("Got observation")
        #print(type(observation))
#        print(observation)
        observation = torch.tensor(observation['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
        print("Observation shape: " + str(observation.shape))
        #print("Observation formatted")
        #print(observation)
        observation.requires_grad = True
        #observation = observation.to(top_dev)
        act = epsilon_greedy(model, epsilon, observation, n_actions, t, rng, check=True)
        #print("Action index: " + str(act))
##        print("Got action")
        action = act_range[act]
        #print("Action selection: " + str(action))
        #scores.append(total_reward)
        while t < max_steps:
            print("Episode: " + str(episode) + " , " + "Timestep: " + str(t))
            if render:
                env_record.render()
            #print("Time Step: " + str(t))
            t += 1
            target_update_counter += 1
            seedVal = int(time.time())
            np.random.seed(seedVal)
            next_obs, reward, done, _, info = env_record.step(action)
            print("Step reward: " + str(reward))
            #print("Step reward: " + str(reward))
            #print(type(next_obs))
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(147, 1).to(top_dev)
            next_obs.requires_grad = True
            #new_obs = next_obs.to(main_device)
            total_reward += reward
            act_ = epsilon_greedy(model, epsilon, next_obs, n_actions, t, rng, check=True)
            #print("Action index: " + str(act))
            action_ = act_range[act_]
            #print("Action selection: " + str(action))
            memory.push(observation, act, reward, next_obs, done)
            if len(memory) > batch_size and done:
                batch_sampled = memory.sample(batch_size)
##                batch = Transition(*zip(*transitions))
##                non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
##                non_final_next_states = torch.cat([s for s in batch.next_state
##                                                if s is not None])
##                state_batch = torch.cat(batch.state)
##                action_batch = torch.cat(batch.action)
##                reward_batch = torch.cat(batch.reward)
                Qtarget = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in batch_sampled]
                loss = cost(model, batch_sampled, Qtarget, top_dev)
                #print("Loss Gradient Function: " + str(loss.grad_fn))
##                grads = torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)
##                print("Grad type: " + str(grads))
                #print("Loss: " + str(loss))

                #optimizer_mps.zero_grad()
                opt.zero_grad()
                    #print(loss)
                    #print(type(loss))
##                loss.backward()
##                for param in model.parameters():
##                    print("Parameter gradient: " + str(param.grad))
                opt.step()
                #optimizer_mps.step()
                #print(model.parameters())
                #print("Optimization step")
            #scheduler.step()
            #scheduler_mps.step()
##            current_replay_memory = memory.output_all()
##            current_target_for_replay_memory = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in current_replay_memory]
            if target_update_counter >= target_update:
                target_update_counter = 0
            observation, action = next_obs, action_
            if done or t == max_steps:
                epsilon = epsilon / ((episode / 300) + 1)
                alpha = 0.95 * alpha
                timestep_reward.append(total_reward)
                print("Reward data length: " + str(len(timestep_reward)))
                iter_index.append(episode)
                iter_total_steps.append(t)
                break
    stop_time = time.asctime()
    print("Start time: ")
    print(start_time)
    print("Stop time: ")
    print(stop_time)
    torch.save(model.state_dict(), param_file)
    return timestep_reward, iter_index, iter_reward, iter_total_steps

def test_agent(model, env_folder, epsilon, env_name, config_name, n_tests, max_steps, delay=1):
    act_range = [0, 1, 2, 6]
    n_successes = 0
    test_rewards = []
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit = model.to(main_device)
    env = gym.make(env_name, max_episode_steps=max_steps, render_mode="human", height=64, width=64)
    env = SymbolicObsWrapper(env)
    env = ImgObsFlatWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
    done = False
    for test in range(n_tests):
        reward_total = 0
        epsilon = 0
        env.reset()
        filename = env_folder + "/" + env_name + "_" + str(test)
        statefile = open(filename, "wb")
        state_data = pkl.load(statefile)
        new_grid = env_record.env.grid.decode(state_data)
        env_record.env.grid = new_grid
        while True:
            time.sleep(delay)
            s = torch.tensor(observation).type('torch.FloatTensor').view(1, -1)
            act = epsilon_greedy(model, epsilon, observation)
            a = act_range[act]
            next_obs, reward, done, info = env_record.step(a)
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(1, -1)
            reward_total += reward
            if done:
                if reward > 0:
                    n_successes += 1
                    print("Goal Reached")
                else:
                    print("Task Failed")
                test_rewards.append(reward_total)
                time.sleep(3)
                break
    return test_rewards, n_sucesses

def main():
##    register(
##        id="Minigrid-RandomLava-6Spots-v0",
##        entry_point="RandomLavaMinigrid:RandomLavaEnv",
##        kwargs={"size": 8, "n_obstacles": 6})
    #env_name = "Minigrid-RandomLava-6Spots-v0"
    env_name = "MiniGrid-Empty-8x8-v0"
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    #device = torch.device("cpu")
    alpha = 0.4
    gamma = 0.5
    epsilon = 1
    episodes = 1000
    max_steps = 100
    n_actions = 4
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    model = TreeTensorAgent(147).to(device)
    print("Model graph: " + str(model.graph))

    tn_opt = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(tn_opt, T_max=episodes)
    timestep_reward, iter_index, iter_reward, iter_total_steps = ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, device, tn_opt, scheduler)
    x_vals = np.arange(episodes)
    y_vals = np.asarray(timestep_reward)
    fig, ax = plt.subplots()
    ax.plot(x_vals, y_vals)
    ax.grid()
    ax.set(xlabel="Episode", ylabel="Total Score", title="Deep Quantum TTN Learning Training Process: 6-Site Random Lava Minigrid")
    fig.savefig("TTNTrain.png")
    plt.close(fig)

if __name__ == "__main__":
    main()