mit-han-lab / torchquantum

A PyTorch-based framework for Quantum Classical Simulation, Quantum Machine Learning, Quantum Neural Networks, Parameterized Quantum Circuits with support for easy deployments on real quantum computers.
MIT License
1.24k stars 185 forks source link

Torchquantum reinforcement learning agent behaves randomly even after epsilon-greedy phase #124

Open JustinS6626 opened 1 year ago

JustinS6626 commented 1 year ago

First of all, thanks again so much for your help with the parameter updating problem that I reported a few weeks ago. I have now encountered another issue. What seems to be happening is that even after the epsilon-greedy phase is over, the agent still behaves randomly. Since I deactivated randomness for the linear and convolutional layers, I am wondering what the source of the problem is. My code as it is right now is provided below:

import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
import calendar
import random
from minigrid.wrappers import *
import logging
from torchpack.callbacks import (InferenceRunner, MeanAbsoluteError,
                                 MaxSaver, MinSaver,
                                 Saver, SaverRestore, CategoricalAccuracy)
from torchpack.environ import set_run_dir
from torchpack.utils.config import configs
from torchpack.utils.logging import logger
from torchtest import assert_vars_change
from torch.nn.parameter import Parameter
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.measurement import *
import matplotlib.pyplot as plt
import pickle as pkl
from obs_wrappers import ImgObsFlatWrapper
import gymnasium as gym
from gymnasium.wrappers.record_video import RecordVideo
from collections import namedtuple, deque
#from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistic
import numpy as np
import os
from gymnasium.envs.registration import *
##from ._utils import _import_dotted_name
##from ._six import string_classes as _string_classes
##from torch._sources import get_source_lines_and_file
##from torch.types import Storage
##from import _get_dtype_from_pickle_storage_type
##from typing_extensions import TypeAlias
##import copyreg

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def output_all(self):
        return self.memory

    def __len__(self):
        return len(self.memory)

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

def SO4(q_device, RY, RZ, CNOT, wires, static=None, parent_graph=None):
##    rz_pi = np.asarray([[np.exp(-1j * (np.pi / 4)), 0],
##                        [0, np.exp(1j * (np.pi / 4))]])
##    rz_neg_pi = np.asarray([[np.exp(1j * (np.pi / 4)), 0],
##                        [0, np.exp(-1j * (np.pi / 4))]])
##    ry_pi = np.asarray([[np.cos(np.pi / 4), -1 * np.sin(np.pi / 4)],
##                        [np.sin(np.pi / 4), np.cos(np.pi / 4)]])
##    ry_neg_pi = np.asarray([[np.cos(-1 * np.pi / 4), -1 * np.sin(-1 * np.pi / 4)],
##                            [np.sin(-1 * np.pi / 4), np.cos(-1 * np.pi / 4)]])
##    tqf.qubitunitary(device, wires=wires[0], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_pi)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    RZ[0](q_device, wires=wires[0])
    RZ[1](q_device, wires=wires[1])
    RY[0](q_device, wires=wires[0])
    RY[1](q_device, wires=wires[1])
    RZ[2](q_device, wires=wires[0])
    RZ[3](q_device, wires=wires[1])
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_neg_pi)
##    tqf.qubitunitary(device, wires=wires[0], params=rz_neg_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_neg_pi)

class TreeTensorAgent(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            self.n_wires = 8
            self.n_actions = 4
##                self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

            #self.bias = torch.tensor(np.random.rand(4), requires_grad=True)
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

        def forward(self, q_device, static_mode, graph):
            self.q_device = q_device
            #SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 1 Start
            tqf.rz(q_device, wires=0, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            self.rz_0_0(q_device, wires=0)
            self.rz_0_1(q_device, wires=1)
            self.ry_0_0(q_device, wires=0)
            self.ry_0_1(q_device, wires=1)
            self.rz_0_2(q_device, wires=0)
            self.rz_0_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=0, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 1 End

            #SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 2 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            self.rz_1_0(q_device, wires=2)
            self.rz_1_1(q_device, wires=3)
            self.ry_1_0(q_device, wires=2)
            self.ry_1_1(q_device, wires=3)
            self.rz_1_2(q_device, wires=2)
            self.rz_1_3(q_device, wires=3)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 2 End

            #SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 3 Start
            tqf.rz(q_device, wires=4, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            self.rz_2_0(q_device, wires=4)
            self.rz_2_1(q_device, wires=5)
            self.ry_2_0(q_device, wires=4)
            self.ry_2_1(q_device, wires=5)
            self.rz_2_2(q_device, wires=4)
            self.rz_2_3(q_device, wires=5)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=4, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 3 End

            #SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 4 Start
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            self.rz_3_0(q_device, wires=6)
            self.rz_3_1(q_device, wires=7)
            self.ry_3_0(q_device, wires=6)
            self.ry_3_1(q_device, wires=7)
            self.rz_3_2(q_device, wires=6)
            self.rz_3_3(q_device, wires=7)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 4 End

            #SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 1 Start
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            self.rz_4_0(q_device, wires=1)
            self.rz_4_1(q_device, wires=2)
            self.ry_4_0(q_device, wires=1)
            self.ry_4_1(q_device, wires=2)
            self.rz_4_2(q_device, wires=1)
            self.rz_4_3(q_device, wires=2)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 1 End

            #SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 2 Start
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            self.rz_5_0(q_device, wires=5)
            self.rz_5_1(q_device, wires=6)
            self.ry_5_0(q_device, wires=5)
            self.ry_5_1(q_device, wires=6)
            self.rz_5_2(q_device, wires=5)
            self.rz_5_3(q_device, wires=6)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 2 End
            #SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 3 Gate 1 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode)
            self.rz_6_0(q_device, wires=0)
            self.rz_6_1(q_device, wires=1)
            self.ry_6_0(q_device, wires=0)
            self.ry_6_1(q_device, wires=1)
            self.rz_6_2(q_device, wires=0)
            self.rz_6_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode, parent_graph=graph)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 3 Gate 1 End
    def __init__(self, input_size):
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        #self.bias = Parameter(torch.zeros(self.n_actions))
        self.smx = nn.Softmax()
        self.bitstrings = gen_bitstrings(self.n_wires)
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        #self.mps = MPS(input_dim = 147, output_dim = 8, bond_dim = 2, feature_dim = 2, use_GPU = False, parallel = True, init_std=1e-2)
##        self.feature_map = nn.Sequential(nn.Linear(self.input_size, 64), nn.ReLU(),
##                                         nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2), nn.ReLU(),
##                                         nn.Conv1d(64, 1, kernel_size=2, stride=3), nn.Tanh())
        self.layer_1 = nn.Linear(1,256)
        #147 x 128
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 256, kernel_size=2, padding=1, dilation=2)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(256, 1, kernel_size=4, stride=8, padding=1, dilation=3)
        self.layer_6 = nn.CELU()
##        for param in self.mps.parameters():
##            print("Gradient required: " + str(param.requires_grad))

##        self.measure = tq.MeasureMultiPauliSum(
##            obs_list=[{"wires" : [3, 4],
##                       "observables" : ['z', 'z'],
##                       "coefficient" : [1, 1]}])

##        self.encoder1 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
##             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
##             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
##             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
##             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
##             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
##             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
##             {"input_idx" : [7], "func" : "ry", "wires" : [7]}])
##        self.encoder2 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "rz", "wires" : [0]},
##             {"input_idx" : [1], "func" : "rz", "wires" : [1]},
##             {"input_idx" : [2], "func" : "rz", "wires" : [2]},
##             {"input_idx" : [3], "func" : "rz", "wires" : [3]},
##             {"input_idx" : [4], "func" : "rz", "wires" : [4]},
##             {"input_idx" : [5], "func" : "rz", "wires" : [5]},
##             {"input_idx" : [6], "func" : "rz", "wires" : [6]},
##             {"input_idx" : [7], "func" : "rz", "wires" : [7]}])

            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "rz", "wires" : [0]},
             {"input_idx" : [9], "func" : "rz", "wires" : [1]},
             {"input_idx" : [10], "func" : "rz", "wires" : [2]},
             {"input_idx" : [11], "func" : "rz", "wires" : [3]},
             {"input_idx" : [12], "func" : "rz", "wires" : [4]},
             {"input_idx" : [13], "func" : "rz", "wires" : [5]},
             {"input_idx" : [14], "func" : "rz", "wires" : [6]},
             {"input_idx" : [15], "func" : "rz", "wires" : [7]},
             {"input_idx" : [16], "func" : "rx", "wires" : [0]},
             {"input_idx" : [17], "func" : "rx", "wires" : [1]},
             {"input_idx" : [18], "func" : "rx", "wires" : [2]},
             {"input_idx" : [19], "func" : "rx", "wires" : [3]},
             {"input_idx" : [20], "func" : "rx", "wires" : [4]},
             {"input_idx" : [21], "func" : "rx", "wires" : [5]},
             {"input_idx" : [22], "func" : "rx", "wires" : [6]},
             {"input_idx" : [23], "func" : "rx", "wires" : [7]},
             {"input_idx" : [24], "func" : "rz", "wires" : [0]},
             {"input_idx" : [25], "func" : "rz", "wires" : [1]},
             {"input_idx" : [26], "func" : "rz", "wires" : [2]},
             {"input_idx" : [27], "func" : "rz", "wires" : [3]},
             {"input_idx" : [28], "func" : "rz", "wires" : [4]},
             {"input_idx" : [29], "func" : "rz", "wires" : [5]},
             {"input_idx" : [30], "func" : "rz", "wires" : [6]},
             {"input_idx" : [31], "func" : "rz", "wires" : [7]}])

##    def get_angles_atan(self, in_x):
##        angles = torch.stack([torch.stack([torch.atan(item), torch.atan(item**2)]) for item in in_x])
##        return angles

    def forward(self, input_data, check=False):
        #measure_counts = np.zeros(self.n_actions)
        prob_dict = {}
        #x = self.feature_map(input_data)
        x_1 = self.layer_1(input_data)
        x_2 = self.layer_2(x_1)
        #print("Stage one size: " + str(x_2.shape))
        x_3 = self.layer_3(x_2)
        x_4 = self.layer_4(x_3)
        #print("Stage two size: " + str(x_3.shape))
        x_5 = self.layer_5(x_4)
        x_6 = self.layer_6(x_5)
        #print("Stage three size " + str(x_6.shape))
        #x_angles = self.get_angles_atan(x)
        #x_angles = torch.stack([torch.atan(x), torch.atan(x ** 2)])
##        print("Angle array shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
        #torch.reshape(x_angles, (1, 16))
##        new_x_angles = x_angles.view(1, 16)
##        x_angles = new_x_angles
##        print("Input shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
##        if check:
##            print(x_angles)
        x_angles = torch.atan(x_6)
        for i in range(self.n_wires):

            tqf.hadamard(self.q_device, wires=i, static=self.static_mode, parent_graph=self.graph)

##        self.encoder1(self.q_device, x_angles[0][0])
##        self.encoder2(self.q_device, x_angles[1][0])
        self.encoder(self.q_device, x_angles)
        #print("Parent graph: " + str(self.graph))
        self.q_layer.forward(self.q_device, self.static_mode, self.graph)

##        SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1])
##        SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3])
##        SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5])
##        SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7])
##        SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2])
##        SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6])
##        SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5])
        #print("ops done")
##        device_states = self.q_device.get_states_1d()
##        #print(device_states)
##        circuit_state = tq.QuantumState(n_wires=self.n_wires)
##        circuit_state.set_states(device_states)
        #state_vec = self.q_device.get_states_1d().abs().detach().cpu().numpy()
##        print("State vector: ")
##        print(state_vec)
##        print(state_vec.shape)
        #measures = tq.measure(self.q_device, n_shots=4096)
##        for i in range(len(self.bitstrings)):
##            prob_dict[self.bitstrings[i]] = np.abs(state_vec[0][i]) ** 2
##        qbit_states = list(prob_dict.keys())
##        print(type(measure_results))
        #qbit_states = [result.keys() for result in measure_results]
##        for bitkey in qbit_states:
##            if bitkey[3] == '0' and bitkey[4] == '0':
##                measure_counts[0] += prob_dict[bitkey]
##            elif bitkey[3] == '0' and bitkey[4] == '1':
##                measure_counts[1] += prob_dict[bitkey]
##            elif bitkey[3] == '1' and bitkey[4] == '0':
##                measure_counts[2] += prob_dict[bitkey]
##            else:
##                measure_counts[3] += prob_dict[bitkey]
##        measure_norm = np.linalg.norm(measure_counts)
##        measure_counts = measure_counts / measure_norm
##        if check:
##            print("Measure outcomes: ")
##            print(measure_counts)
        #measure_weights = torch.tensor(measure_counts, requires_grad=True)
        obs_1 = expval_joint_analytical(self.q_device, "ZZZXXZZZ")
        obs_2 = expval_joint_analytical(self.q_device, "ZZZYYZZZ")
        obs_3 = expval_joint_analytical(self.q_device, "ZZZYXZZZ")
        obs_4 = expval_joint_analytical(self.q_device, "ZZZXYZZZ")
        expectations = torch.stack([obs_1, obs_2, obs_3, obs_4], dim=1)
        #measure_weights = self.smx(measure_results)
        if check:
            print("Measure weights: ")
        #print("Gradient preserved: " + str(measure_weights.requires_grad))
        measure_weights = expectations.view(4)
        #print("Output shape: " + str(measure_weights.shape))
##        if check:
##            print(measure_weights)
        #print("Measure results")
        return measure_weights

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
##        print(type(l))
##        print(type(p))
        loss = loss + ((l - p) ** 2)
    loss = loss / len(labels)
    return loss

def epsilon_greedy(TreeTensor, epsilon, s, n_actions, timestep, rgen, check=False, train=False):
##    seed = int(time.time())
##    rng = np.random.default_rng(seed)
    if train or rgen.random() < ((epsilon / n_actions) + (1 - epsilon)):
        with torch.no_grad():
            measurements = TreeTensor(s, check=check)
            action = torch.argmax(measurements)
            if check:
                print("Argmax result: " + str(action))
            return action

##        seedval = int(time.time())
##        np.random.seed(seedVal)
        action = rgen.integers(0, high=n_actions)
##        if check:
##            print(choices)
        #action = np.bincount(choices).argmax()
##        if check:
##            print(action)

        action = torch.tensor(action)
        return action

def cost(model, features, labels, dev):
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    return loss_total

def ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, top_dev, opt, sched, render=True):

    act_range = [0, 1, 2, 6]
    logging.basicConfig(filename="ExperimentDebug1.txt", level=logging.DEBUG)
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit =

    param_file = "TTN_params.bin"
    scores = []
    target_update = 20
    batch_size = 100
    optimize_steps = 5
    target_update_counter = 0
    iter_index = []
    iter_reward = []
    iter_total_steps = []
    cost_list = []
    timestep_reward = []
    seed = int(time.time())
    rng = np.random.default_rng(seed)
    memory = ReplayMemory(500)
##    seedVal = int(time.time())
##    np.random.seed(seedVal)
    #q_device = tq.QuantumDevice(n_wires=8)
    #optimizer = optim.Adam(model.parameters(), lr=alpha, weight_decay=1e-4)
    #optimizer_mps = optim.Adam(model.mps.parameters(), lr=alpha, weight_decay=1e-4)
##    optimizer = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
##    optimizer_mps = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
    #scheduler = CosineAnnealingLR(optimizer, T_max=episodes)
    #scheduler_mps = CosineAnnealingLR(optimizer_mps, T_max=episodes)
    env = gym.make(env_name, max_episode_steps=max_steps, disable_env_checker=True, render_mode="human")
    env = ImgObsFlatWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
##    record_dict = env_record.__dict__
##    record_keys = record_dict.keys()
##    print(record_keys)
    start_state = None
    start_time = time.asctime()
    for episode in range(episodes):
        print("Episode: " + str(episode))
        t = 0
        total_reward = 0
        #print("Reset reward")
        done = False
        #print("Not done")
        #rgen = np.random.RandomState(seedVal)
        seedVal = int(time.time())
##        print(type(observation))
##        print(observation)
        #print("Reset complete")
        if episode == 0:
            start_state = env_record.env.grid
            #print("Start state set")
            env_record.env.grid = start_state
            #print("Start state retrieved")
        if render:
        #print("Number of obstacles: " + str(len(env_record.env.obstacles)))
        observation = env_record.env.gen_obs()
        #print("Got observation")
#        print(observation)
        observation = torch.tensor(observation['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
        print("Observation shape: " + str(observation.shape))
        #print("Observation formatted")
        observation.requires_grad = True
        #observation =
        act = epsilon_greedy(model, epsilon, observation, n_actions, t, rng, check=True)
        #print("Action index: " + str(act))
##        print("Got action")
        action = act_range[act]
        #print("Action selection: " + str(action))
        while t < max_steps:
            print("Episode: " + str(episode) + " , " + "Timestep: " + str(t))
            if render:
            #print("Time Step: " + str(t))
            t += 1
            target_update_counter += 1
            seedVal = int(time.time())
            next_obs, reward, done, _, info = env_record.step(action)
            print("Step reward: " + str(reward))
            #print("Step reward: " + str(reward))
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(147, 1).to(top_dev)
            next_obs.requires_grad = True
            #new_obs =
            total_reward += reward
            act_ = epsilon_greedy(model, epsilon, next_obs, n_actions, t, rng, check=True)
            #print("Action index: " + str(act))
            action_ = act_range[act_]
            #print("Action selection: " + str(action))
            memory.push(observation, act, reward, next_obs, done)
            if len(memory) > batch_size and done:
                batch_sampled = memory.sample(batch_size)
##                batch = Transition(*zip(*transitions))
##                non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
##                non_final_next_states =[s for s in batch.next_state
##                                                if s is not None])
##                state_batch =
##                action_batch =
##                reward_batch =
                Qtarget = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in batch_sampled]
                loss = cost(model, batch_sampled, Qtarget, top_dev)
                #print("Loss Gradient Function: " + str(loss.grad_fn))
##                grads = torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)
##                print("Grad type: " + str(grads))
                #print("Loss: " + str(loss))

##                loss.backward()
##                for param in model.parameters():
##                    print("Parameter gradient: " + str(param.grad))
                #print("Optimization step")
##            current_replay_memory = memory.output_all()
##            current_target_for_replay_memory = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in current_replay_memory]
            if target_update_counter >= target_update:
                target_update_counter = 0
            observation, action = next_obs, action_
            if done or t == max_steps:
                epsilon = epsilon / ((episode / 750) + 1)
                alpha = 0.95 * alpha
                print("Reward data length: " + str(len(timestep_reward)))
    stop_time = time.asctime()
    print("Start time: ")
    print("Stop time: ")
    print(stop_time), param_file)
    return timestep_reward, iter_index, iter_reward, iter_total_steps

def test_agent(model, env_folder, epsilon, env_name, config_name, n_tests, max_steps, delay=1):
    act_range = [0, 1, 2, 6]
    n_successes = 0
    test_rewards = []
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit =
    env = gym.make(env_name, max_episode_steps=max_steps, render_mode="human", height=64, width=64)
    env = SymbolicObsWrapper(env)
    env = ImgObsFlatWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
    done = False
    for test in range(n_tests):
        reward_total = 0
        epsilon = 0
        filename = env_folder + "/" + env_name + "_" + str(test)
        statefile = open(filename, "wb")
        state_data = pkl.load(statefile)
        new_grid = env_record.env.grid.decode(state_data)
        env_record.env.grid = new_grid
        while True:
            s = torch.tensor(observation).type('torch.FloatTensor').view(1, -1)
            act = epsilon_greedy(model, epsilon, observation)
            a = act_range[act]
            next_obs, reward, done, info = env_record.step(a)
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(1, -1)
            reward_total += reward
            if done:
                if reward > 0:
                    n_successes += 1
                    print("Goal Reached")
                    print("Task Failed")
    return test_rewards, n_sucesses

def main():
##    register(
##        id="Minigrid-RandomLava-6Spots-v0",
##        entry_point="RandomLavaMinigrid:RandomLavaEnv",
##        kwargs={"size": 8, "n_obstacles": 6})
    #env_name = "Minigrid-RandomLava-6Spots-v0"
    env_name = "MiniGrid-Empty-8x8-v0"
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    #device = torch.device("cpu")
    alpha = 0.4
    gamma = 0.5
    epsilon = 1
    episodes = 1000
    max_steps = 100
    n_actions = 4
    seed = 42
    model = TreeTensorAgent(147).to(device)
    print("Model graph: " + str(model.graph))

    tn_opt = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(tn_opt, T_max=episodes)
    timestep_reward, iter_index, iter_reward, iter_total_steps = ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, device, tn_opt, scheduler)
    x_vals = np.arange(episodes)
    y_vals = np.asarray(timestep_reward)
    fig, ax = plt.subplots()
    ax.plot(x_vals, y_vals)
    ax.set(xlabel="Episode", ylabel="Total Score", title="Deep Quantum TTN Learning Training Process: 6-Site Random Lava Minigrid")

if __name__ == "__main__":

I would greatly appreciate your help.

JustinS6626 commented 1 year ago

I just notice that the code that I posted has a bad import. This one should replicate the issue:

import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
import calendar
import random
from minigrid.wrappers import *
import logging
from torchpack.callbacks import (InferenceRunner, MeanAbsoluteError,
                                 MaxSaver, MinSaver,
                                 Saver, SaverRestore, CategoricalAccuracy)
from torchpack.environ import set_run_dir
from torchpack.utils.config import configs
from torchpack.utils.logging import logger
from torchtest import assert_vars_change
from torch.nn.parameter import Parameter
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.measurement import *
import matplotlib.pyplot as plt
import pickle as pkl
import gymnasium as gym
from gymnasium.wrappers.record_video import RecordVideo
from collections import namedtuple, deque
#from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistic
import numpy as np
import os
from gymnasium.envs.registration import *
##from ._utils import _import_dotted_name
##from ._six import string_classes as _string_classes
##from torch._sources import get_source_lines_and_file
##from torch.types import Storage
##from import _get_dtype_from_pickle_storage_type
##from typing_extensions import TypeAlias
##import copyreg

#os.environ["SDL_VIDEODRIVER"] = "dummy"

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def output_all(self):
        return self.memory

    def __len__(self):
        return len(self.memory)

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

def SO4(q_device, RY, RZ, CNOT, wires, static=None, parent_graph=None):
##    rz_pi = np.asarray([[np.exp(-1j * (np.pi / 4)), 0],
##                        [0, np.exp(1j * (np.pi / 4))]])
##    rz_neg_pi = np.asarray([[np.exp(1j * (np.pi / 4)), 0],
##                        [0, np.exp(-1j * (np.pi / 4))]])
##    ry_pi = np.asarray([[np.cos(np.pi / 4), -1 * np.sin(np.pi / 4)],
##                        [np.sin(np.pi / 4), np.cos(np.pi / 4)]])
##    ry_neg_pi = np.asarray([[np.cos(-1 * np.pi / 4), -1 * np.sin(-1 * np.pi / 4)],
##                            [np.sin(-1 * np.pi / 4), np.cos(-1 * np.pi / 4)]])
##    tqf.qubitunitary(device, wires=wires[0], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_pi)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    RZ[0](q_device, wires=wires[0])
    RZ[1](q_device, wires=wires[1])
    RY[0](q_device, wires=wires[0])
    RY[1](q_device, wires=wires[1])
    RZ[2](q_device, wires=wires[0])
    RZ[3](q_device, wires=wires[1])
    tqf.cnot(q_device, wires=[wires[1], wires[0]], static=static_mode, parent_graph=graph)
    tqf.ry(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[0], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
    tqf.rz(q_device, wires=wires[1], params=torch.tensor([-np.pi / 2]), static=static_mode, parent_graph=graph)
##    tqf.qubitunitary(device, wires=wires[1], params=ry_neg_pi)
##    tqf.qubitunitary(device, wires=wires[0], params=rz_neg_pi)
##    tqf.qubitunitary(device, wires=wires[1], params=rz_neg_pi)

class TreeTensorAgent(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            self.n_wires = 8
            self.n_actions = 4
##                self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

            #self.bias = torch.tensor(np.random.rand(4), requires_grad=True)
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

        def forward(self, q_device):
            self.q_device = q_device
            #SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 1 Start
            tqf.rz(q_device, wires=0, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=1, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[1, 0])
            self.rz_0_0(q_device, wires=0)
            self.rz_0_1(q_device, wires=1)
            self.ry_0_0(q_device, wires=0)
            self.ry_0_1(q_device, wires=1)
            self.rz_0_2(q_device, wires=0)
            self.rz_0_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[1, 0])
            tqf.ry(q_device, wires=1, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=0, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 1 End

            #SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 2 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=3, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=3, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[3, 2])
            self.rz_1_0(q_device, wires=2)
            self.rz_1_1(q_device, wires=3)
            self.ry_1_0(q_device, wires=2)
            self.ry_1_1(q_device, wires=3)
            self.rz_1_2(q_device, wires=2)
            self.rz_1_3(q_device, wires=3)
            tqf.cnot(q_device, wires=[3, 2])
            tqf.ry(q_device, wires=3, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=3, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 2 End

            #SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 3 Start
            tqf.rz(q_device, wires=4, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[5, 4])
            self.rz_2_0(q_device, wires=4)
            self.rz_2_1(q_device, wires=5)
            self.ry_2_0(q_device, wires=4)
            self.ry_2_1(q_device, wires=5)
            self.rz_2_2(q_device, wires=4)
            self.rz_2_3(q_device, wires=5)
            tqf.cnot(q_device, wires=[5, 4])
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=4, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 3 End

            #SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 4 Start
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=7, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=7, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[7, 6])
            self.rz_3_0(q_device, wires=6)
            self.rz_3_1(q_device, wires=7)
            self.ry_3_0(q_device, wires=6)
            self.ry_3_1(q_device, wires=7)
            self.rz_3_2(q_device, wires=6)
            self.rz_3_3(q_device, wires=7)
            tqf.cnot(q_device, wires=[7, 6])
            tqf.ry(q_device, wires=7, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=7, params=torch.tensor([-np.pi / 2]))
            #Layer 1 Gate 4 End

            #SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 1 Start
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[2, 1])
            self.rz_4_0(q_device, wires=1)
            self.rz_4_1(q_device, wires=2)
            self.ry_4_0(q_device, wires=1)
            self.ry_4_1(q_device, wires=2)
            self.rz_4_2(q_device, wires=1)
            self.rz_4_3(q_device, wires=2)
            tqf.cnot(q_device, wires=[2, 1])
            tqf.ry(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            #Layer 2 Gate 1 End

            #SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 2 Start
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=6, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[6, 5])
            self.rz_5_0(q_device, wires=5)
            self.rz_5_1(q_device, wires=6)
            self.ry_5_0(q_device, wires=5)
            self.ry_5_1(q_device, wires=6)
            self.rz_5_2(q_device, wires=5)
            self.rz_5_3(q_device, wires=6)
            tqf.cnot(q_device, wires=[6, 5])
            tqf.ry(q_device, wires=6, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]))
            #Layer 2 Gate 2 End
            #SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 3 Gate 1 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]))
            tqf.cnot(q_device, wires=[5, 2])
            self.rz_6_0(q_device, wires=0)
            self.rz_6_1(q_device, wires=1)
            self.ry_6_0(q_device, wires=0)
            self.ry_6_1(q_device, wires=1)
            self.rz_6_2(q_device, wires=0)
            self.rz_6_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[5, 2])
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]))
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]))
            #Layer 3 Gate 1 End
    def __init__(self, input_size):
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        #self.bias = Parameter(torch.zeros(self.n_actions))
        self.smx = nn.Softmax()
        self.bitstrings = gen_bitstrings(self.n_wires)
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        #self.mps = MPS(input_dim = 147, output_dim = 8, bond_dim = 2, feature_dim = 2, use_GPU = False, parallel = True, init_std=1e-2)
##        self.feature_map = nn.Sequential(nn.Linear(self.input_size, 64), nn.ReLU(),
##                                         nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2), nn.ReLU(),
##                                         nn.Conv1d(64, 1, kernel_size=2, stride=3), nn.Tanh())
        self.layer_1 = nn.Linear(1, 256, bias=False)
        #147 x 128
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 256, kernel_size=2, padding=1, dilation=2, bias=False)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(256, 1, kernel_size=4, stride=8, padding=1, dilation=3, bias=False)
        self.layer_6 = nn.CELU()
##        for param in self.mps.parameters():
##            print("Gradient required: " + str(param.requires_grad))

##        self.measure = tq.MeasureMultiPauliSum(
##            obs_list=[{"wires" : [3, 4],
##                       "observables" : ['z', 'z'],
##                       "coefficient" : [1, 1]}])

##        self.encoder1 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
##             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
##             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
##             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
##             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
##             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
##             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
##             {"input_idx" : [7], "func" : "ry", "wires" : [7]}])
##        self.encoder2 = tq.GeneralEncoder(
##            [{"input_idx" : [0], "func" : "rz", "wires" : [0]},
##             {"input_idx" : [1], "func" : "rz", "wires" : [1]},
##             {"input_idx" : [2], "func" : "rz", "wires" : [2]},
##             {"input_idx" : [3], "func" : "rz", "wires" : [3]},
##             {"input_idx" : [4], "func" : "rz", "wires" : [4]},
##             {"input_idx" : [5], "func" : "rz", "wires" : [5]},
##             {"input_idx" : [6], "func" : "rz", "wires" : [6]},
##             {"input_idx" : [7], "func" : "rz", "wires" : [7]}])

            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "rz", "wires" : [0]},
             {"input_idx" : [9], "func" : "rz", "wires" : [1]},
             {"input_idx" : [10], "func" : "rz", "wires" : [2]},
             {"input_idx" : [11], "func" : "rz", "wires" : [3]},
             {"input_idx" : [12], "func" : "rz", "wires" : [4]},
             {"input_idx" : [13], "func" : "rz", "wires" : [5]},
             {"input_idx" : [14], "func" : "rz", "wires" : [6]},
             {"input_idx" : [15], "func" : "rz", "wires" : [7]},
             {"input_idx" : [16], "func" : "rx", "wires" : [0]},
             {"input_idx" : [17], "func" : "rx", "wires" : [1]},
             {"input_idx" : [18], "func" : "rx", "wires" : [2]},
             {"input_idx" : [19], "func" : "rx", "wires" : [3]},
             {"input_idx" : [20], "func" : "rx", "wires" : [4]},
             {"input_idx" : [21], "func" : "rx", "wires" : [5]},
             {"input_idx" : [22], "func" : "rx", "wires" : [6]},
             {"input_idx" : [23], "func" : "rx", "wires" : [7]},
             {"input_idx" : [24], "func" : "rz", "wires" : [0]},
             {"input_idx" : [25], "func" : "rz", "wires" : [1]},
             {"input_idx" : [26], "func" : "rz", "wires" : [2]},
             {"input_idx" : [27], "func" : "rz", "wires" : [3]},
             {"input_idx" : [28], "func" : "rz", "wires" : [4]},
             {"input_idx" : [29], "func" : "rz", "wires" : [5]},
             {"input_idx" : [30], "func" : "rz", "wires" : [6]},
             {"input_idx" : [31], "func" : "rz", "wires" : [7]}])

##    def get_angles_atan(self, in_x):
##        angles = torch.stack([torch.stack([torch.atan(item), torch.atan(item**2)]) for item in in_x])
##        return angles

    def forward(self, input_data, check=False):
        #measure_counts = np.zeros(self.n_actions)
        prob_dict = {}
        #x = self.feature_map(input_data)
        x_1 = self.layer_1(input_data)
        x_2 = self.layer_2(x_1)
        #print("Stage one size: " + str(x_2.shape))
        x_3 = self.layer_3(x_2)
        x_4 = self.layer_4(x_3)
        #print("Stage two size: " + str(x_3.shape))
        x_5 = self.layer_5(x_4)
        x_6 = self.layer_6(x_5)
        #print("Stage three size " + str(x_6.shape))
        #x_angles = self.get_angles_atan(x)
        #x_angles = torch.stack([torch.atan(x), torch.atan(x ** 2)])
##        print("Angle array shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
        #torch.reshape(x_angles, (1, 16))
##        new_x_angles = x_angles.view(1, 16)
##        x_angles = new_x_angles
##        print("Input shape: " + str(x_angles.shape))
##        print("Gradient preserved: " + str(x_angles.requires_grad))
##        if check:
##            print(x_angles)
        x_angles = torch.atan(x_6)
        for i in range(self.n_wires):

            tqf.hadamard(self.q_device, wires=i)

##        self.encoder1(self.q_device, x_angles[0][0])
##        self.encoder2(self.q_device, x_angles[1][0])
        self.encoder(self.q_device, x_angles)
        #print("Parent graph: " + str(self.graph))

##        SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1])
##        SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3])
##        SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5])
##        SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7])
##        SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2])
##        SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6])
##        SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5])
        #print("ops done")
##        device_states = self.q_device.get_states_1d()
##        #print(device_states)
##        circuit_state = tq.QuantumState(n_wires=self.n_wires)
##        circuit_state.set_states(device_states)
        #state_vec = self.q_device.get_states_1d().abs().detach().cpu().numpy()
##        print("State vector: ")
##        print(state_vec)
##        print(state_vec.shape)
        #measures = tq.measure(self.q_device, n_shots=4096)
##        for i in range(len(self.bitstrings)):
##            prob_dict[self.bitstrings[i]] = np.abs(state_vec[0][i]) ** 2
##        qbit_states = list(prob_dict.keys())
##        print(type(measure_results))
        #qbit_states = [result.keys() for result in measure_results]
##        for bitkey in qbit_states:
##            if bitkey[3] == '0' and bitkey[4] == '0':
##                measure_counts[0] += prob_dict[bitkey]
##            elif bitkey[3] == '0' and bitkey[4] == '1':
##                measure_counts[1] += prob_dict[bitkey]
##            elif bitkey[3] == '1' and bitkey[4] == '0':
##                measure_counts[2] += prob_dict[bitkey]
##            else:
##                measure_counts[3] += prob_dict[bitkey]
##        measure_norm = np.linalg.norm(measure_counts)
##        measure_counts = measure_counts / measure_norm
##        if check:
##            print("Measure outcomes: ")
##            print(measure_counts)
        #measure_weights = torch.tensor(measure_counts, requires_grad=True)
        obs_1 = expval_joint_analytical(self.q_device, "ZZZXXZZZ")
        obs_2 = expval_joint_analytical(self.q_device, "ZZZYYZZZ")
        obs_3 = expval_joint_analytical(self.q_device, "ZZZYXZZZ")
        obs_4 = expval_joint_analytical(self.q_device, "ZZZXYZZZ")
        expectations = torch.stack([obs_1, obs_2, obs_3, obs_4], dim=1)
        #measure_weights = self.smx(measure_results)
        if check:
            print("Measure weights: ")
        #print("Gradient preserved: " + str(measure_weights.requires_grad))
        measure_weights = expectations.view(4)
        #print("Output shape: " + str(measure_weights.shape))
##        if check:
##            print(measure_weights)
        #print("Measure results")
        return measure_weights

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
##        print(type(l))
##        print(type(p))
        loss = loss + ((l - p) ** 2)
    loss = loss / len(labels)
    return loss

def epsilon_greedy(TreeTensor, epsilon, s, n_actions, timestep, rgen, check=False, train=False):
##    seed = int(time.time())
##    rng = np.random.default_rng(seed)
    if train or rgen.random() < ((epsilon / n_actions) + (1 - epsilon)):
        with torch.no_grad():
            measurements = TreeTensor(s, check=check)
            action = torch.argmax(measurements)
            if check:
                print("Argmax result: " + str(action))
            return action

##        seedval = int(time.time())
##        np.random.seed(seedVal)
        action = rgen.integers(0, high=n_actions)
##        if check:
##            print(choices)
        #action = np.bincount(choices).argmax()
##        if check:
##            print(action)

        action = torch.tensor(action)
        return action

def cost(model, features, labels, dev):
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    return loss_total

def ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, top_dev, opt, sched, render=True):

    act_range = [0, 1, 2, 6]
    logging.basicConfig(filename="ExperimentDebug1.txt", level=logging.DEBUG)
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit =

    param_file = "TTN_params.bin"
    scores = []
    target_update = 20
    batch_size = 100
    optimize_steps = 5
    target_update_counter = 0
    iter_index = []
    iter_reward = []
    iter_total_steps = []
    cost_list = []
    timestep_reward = []
    seed = int(time.time())
    rng = np.random.default_rng(seed)
    memory = ReplayMemory(500)
##    seedVal = int(time.time())
##    np.random.seed(seedVal)
    #q_device = tq.QuantumDevice(n_wires=8)
    #optimizer = optim.Adam(model.parameters(), lr=alpha, weight_decay=1e-4)
    #optimizer_mps = optim.Adam(model.mps.parameters(), lr=alpha, weight_decay=1e-4)
##    optimizer = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
##    optimizer_mps = optim.SGD(model.parameters(), lr=alpha, momentum=0.9)
    #scheduler = CosineAnnealingLR(optimizer, T_max=episodes)
    #scheduler_mps = CosineAnnealingLR(optimizer_mps, T_max=episodes)
    env = gym.make(env_name, max_episode_steps=max_steps, disable_env_checker=True, render_mode="human")
    env = ImgObsWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
##    record_dict = env_record.__dict__
##    record_keys = record_dict.keys()
##    print(record_keys)
    start_state = None
    start_time = time.asctime()
    for episode in range(episodes):
        print("Episode: " + str(episode))
        t = 0
        total_reward = 0
        #print("Reset reward")
        done = False
        #print("Not done")
        #rgen = np.random.RandomState(seedVal)
        seedVal = int(time.time())
##        print(type(observation))
##        print(observation)
        #print("Reset complete")
        if episode == 0:
            start_state = env_record.env.grid
            #print("Start state set")
            env_record.env.grid = start_state
            #print("Start state retrieved")
        if render:
        #print("Number of obstacles: " + str(len(env_record.env.obstacles)))
        observation = env_record.env.gen_obs()
        #print("Got observation")
#        print(observation)
        observation = torch.tensor(observation['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
        print("Observation shape: " + str(observation.shape))
        #print("Observation formatted")
        observation.requires_grad = True
        #observation =
        act = epsilon_greedy(model, epsilon, observation, n_actions, t, rng, check=True)
        #print("Action index: " + str(act))
##        print("Got action")
        action = act_range[act]
        #print("Action selection: " + str(action))
        while t < max_steps:
            print("Episode: " + str(episode) + " , " + "Timestep: " + str(t))
            if render:
            #print("Time Step: " + str(t))
            t += 1
            target_update_counter += 1
            seedVal = int(time.time())
            next_obs, reward, done, _, info = env_record.step(action)
            print("Step reward: " + str(reward))
            #print("Step reward: " + str(reward))
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(147, 1).to(top_dev)
            next_obs.requires_grad = True
            #new_obs =
            total_reward += reward
            act_ = epsilon_greedy(model, epsilon, next_obs, n_actions, t, rng, check=True)
            #print("Action index: " + str(act))
            action_ = act_range[act_]
            #print("Action selection: " + str(action))
            memory.push(observation, act, reward, next_obs, done)
            if len(memory) > batch_size and done:
                batch_sampled = memory.sample(batch_size)
##                batch = Transition(*zip(*transitions))
##                non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
##                non_final_next_states =[s for s in batch.next_state
##                                                if s is not None])
##                state_batch =
##                action_batch =
##                reward_batch =
                Qtarget = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in batch_sampled]
                loss = cost(model, batch_sampled, Qtarget, top_dev)
                #print("Loss Gradient Function: " + str(loss.grad_fn))
##                grads = torch.autograd.grad(loss, list(model.parameters()), allow_unused=True)
##                print("Grad type: " + str(grads))
                #print("Loss: " + str(loss))

##                loss.backward()
##                for param in model.parameters():
##                    print("Parameter gradient: " + str(param.grad))
                #print("Optimization step")
##            current_replay_memory = memory.output_all()
##            current_target_for_replay_memory = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in current_replay_memory]
            if target_update_counter >= target_update:
                target_update_counter = 0
            observation, action = next_obs, action_
            if done or t == max_steps:
                epsilon = epsilon / ((episode / 300) + 1)
                alpha = 0.95 * alpha
                print("Reward data length: " + str(len(timestep_reward)))
    stop_time = time.asctime()
    print("Start time: ")
    print("Stop time: ")
    print(stop_time), param_file)
    return timestep_reward, iter_index, iter_reward, iter_total_steps

def test_agent(model, env_folder, epsilon, env_name, config_name, n_tests, max_steps, delay=1):
    act_range = [0, 1, 2, 6]
    n_successes = 0
    test_rewards = []
##    use_cuda = torch.cuda.is_available()
##    main_device = torch.device("cuda" if use_cuda else "cpu")
##    circuit =
    env = gym.make(env_name, max_episode_steps=max_steps, render_mode="human", height=64, width=64)
    env = SymbolicObsWrapper(env)
    env = ImgObsFlatWrapper(env)
    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
    done = False
    for test in range(n_tests):
        reward_total = 0
        epsilon = 0
        filename = env_folder + "/" + env_name + "_" + str(test)
        statefile = open(filename, "wb")
        state_data = pkl.load(statefile)
        new_grid = env_record.env.grid.decode(state_data)
        env_record.env.grid = new_grid
        while True:
            s = torch.tensor(observation).type('torch.FloatTensor').view(1, -1)
            act = epsilon_greedy(model, epsilon, observation)
            a = act_range[act]
            next_obs, reward, done, info = env_record.step(a)
            next_obs = torch.tensor(next_obs).type('torch.FloatTensor').view(1, -1)
            reward_total += reward
            if done:
                if reward > 0:
                    n_successes += 1
                    print("Goal Reached")
                    print("Task Failed")
    return test_rewards, n_sucesses

def main():
##    register(
##        id="Minigrid-RandomLava-6Spots-v0",
##        entry_point="RandomLavaMinigrid:RandomLavaEnv",
##        kwargs={"size": 8, "n_obstacles": 6})
    #env_name = "Minigrid-RandomLava-6Spots-v0"
    env_name = "MiniGrid-Empty-8x8-v0"
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    #device = torch.device("cpu")
    alpha = 0.4
    gamma = 0.5
    epsilon = 1
    episodes = 1000
    max_steps = 100
    n_actions = 4
    seed = 42
    model = TreeTensorAgent(147).to(device)
    print("Model graph: " + str(model.graph))

    tn_opt = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(tn_opt, T_max=episodes)
    timestep_reward, iter_index, iter_reward, iter_total_steps = ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, device, tn_opt, scheduler)
    x_vals = np.arange(episodes)
    y_vals = np.asarray(timestep_reward)
    fig, ax = plt.subplots()
    ax.plot(x_vals, y_vals)
    ax.set(xlabel="Episode", ylabel="Total Score", title="Deep Quantum TTN Learning Training Process: 6-Site Random Lava Minigrid")

if __name__ == "__main__":