mit-han-lab / torchquantum

A PyTorch-based framework for Quantum Classical Simulation, Quantum Machine Learning, Quantum Neural Networks, Parameterized Quantum Circuits with support for easy deployments on real quantum computers.
https://torchquantum.org
MIT License
1.33k stars 203 forks source link

Graph attribute of model shows up as Nonetype #97

Open JustinS6626 opened 1 year ago

JustinS6626 commented 1 year ago

I have written a QuantumModule with the following properties, showing only the constructor:

class QPTModel(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 8
            self.n_actions = 4
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

    def __init__(self, input_size):
        super().__init__()
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        self.smx = nn.Softmax()
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)
        self.layer_1 = nn.Linear(1, 64)
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(64, 1, kernel_size=2, stride=2)
        self.layer_6 = nn.Tanh()

        self.encoder=tq.GeneralEncoder(
            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "ry", "wires" : [0]},
             {"input_idx" : [9], "func" : "ry", "wires" : [1]},
             {"input_idx" : [10], "func" : "ry", "wires" : [2]},
             {"input_idx" : [11], "func" : "ry", "wires" : [3]},
             {"input_idx" : [12], "func" : "ry", "wires" : [4]},
             {"input_idx" : [13], "func" : "ry", "wires" : [5]},
             {"input_idx" : [14], "func" : "ry", "wires" : [6]},
             {"input_idx" : [15], "func" : "ry", "wires" : [7]}])

If I create an instance of this model through model = QPTModel(100)

and I print out the attribute model.graph, I get an output of None.

I am not sure if this is supposed to be the default after the model when the class is first instantiated, but if not, I am wondering how to fix it. In this class, aside from the single numerical values, every attribute and function call is part of either torchquantum or standard Pytorch, so there is nothing that should interfere with normal Pytorch functionality. Note that the class instantiation happens in a different source file from the one in which the class is defined. Could that be a problem?

Hanrui-Wang commented 1 year ago

Hi DarthMalloc,

  1. The graph attribute is preserved for backward compatibility. The graph is used to record the operations of a module when the static mode of tq.QuantumModule is activated. In the latest simplified implementation, all operations are instead recorded in the op_list of tq.QuantumDevice, instead of using the graph of tq.QuantumModule. So don't worry about the Nonetype of graph.

  2. "Note that the class instantiation happens in a different source file from the one in which the class is defined." which class are you referring to here?

JustinS6626 commented 1 year ago

Thanks, I think that partly explains the issue that I am having. Basically what's happening is that the parameters of the QuantumModule that I have written are not getting updated at the optimization stage. After the loss.backward() call has been made, I checked the param.grad attribute for the parameters of my model, and they are all showing as a Nonetype. From what you are saying, it sounds like this might have something to do with the way I am using tq.QuantumDevice. Could that be it? For the class instantiation, the QPTModel class was defined in one source file, but the constructor was called in another. I was wondering if that might have been part of the problem.

Hanrui-Wang commented 1 year ago

Hi DarthMalloc,

  1. Just to confirm, you are trying to get the gradient of parameters in the QuantumModule instead of the graph of the QuantumModule, right? I think the gradients should be correctly calculated after loss.backward(). Could you provide more details? Below is an example of gradient computation:
import torchquantum as tq
import torchquantum.functional as tqf

qdev = tq.QuantumDevice(n_wires=2, bsz=5, device="cpu", record_op=True) # use device='cuda' for GPU

op = tq.RX(has_params=True, trainable=True, init_params=0.5)
op(qdev, wires=0)

# obtain the expval on a observable
from torchquantum.measurement import expval_joint_analytical
expval = expval_joint_analytical(qdev, 'ZX')
print(expval)

# obtain gradients of expval w.r.t. trainable parameters
expval[0].backward()
print(op.params.grad)
  1. I think defining QPTModel in one file and instantiate an object in another file should work.
JustinS6626 commented 1 year ago

Thanks for getting back to me so soon! Basically, if the example above was having the same problem that I have seen with my model, the print(ops.params.grad) line at the end of the example would print out None. So basically, my model is standing on its own and is not part of the order of processes.

Hanrui-Wang commented 1 year ago

Thanks for getting back to me so soon! Basically, if the example above was having the same problem that I have seen with my model, the print(ops.params.grad) line at the end of the example would print out None. So basically, my model is standing on its own and is not part of the order of processes.

could you provide the script for me to reproduce the bug?

JustinS6626 commented 1 year ago

Yes, I will post that soon. Since the code is part of the experiment for my PhD dissertation, I need to ask my supervisor first, but I should have that sorted out shortly.

JustinS6626 commented 1 year ago

Here is the code that should reproduce the problem that I have been having. In order to run it, you will need to have the Farama Gymnasium and Minigrid packages installed.

import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
import calendar
import random
from minigrid.wrappers import *
from gymnasium.wrappers.flatten_observation import FlattenObservation
import logging

from torchtest import assert_vars_change
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.measurement import *

import pickle as pkl
import gymnasium as gym
from gymnasium.wrappers.record_video import RecordVideo
from collections import namedtuple, deque
#from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistic
import numpy as np
import os
from gymnasium.envs.registration import *

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def output_all(self):
        return self.memory

    def __len__(self):
        return len(self.memory)

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

class TreeTensorAgent(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 8
            self.n_actions = 4
##                self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

            #self.bias = torch.tensor(np.random.rand(4), requires_grad=True)
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

        def forward(self, q_device, static_mode, graph):
            self.q_device = q_device
            #SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 1 Start
            tqf.rz(q_device, wires=0, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            self.rz_0_0(q_device, wires=0)
            self.rz_0_1(q_device, wires=1)
            self.ry_0_0(q_device, wires=0)
            self.ry_0_1(q_device, wires=1)
            self.rz_0_2(q_device, wires=0)
            self.rz_0_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=0, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 1 End

            #SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 2 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            self.rz_1_0(q_device, wires=2)
            self.rz_1_1(q_device, wires=3)
            self.ry_1_0(q_device, wires=2)
            self.ry_1_1(q_device, wires=3)
            self.rz_1_2(q_device, wires=2)
            self.rz_1_3(q_device, wires=3)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 2 End

            #SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 3 Start
            tqf.rz(q_device, wires=4, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            self.rz_2_0(q_device, wires=4)
            self.rz_2_1(q_device, wires=5)
            self.ry_2_0(q_device, wires=4)
            self.ry_2_1(q_device, wires=5)
            self.rz_2_2(q_device, wires=4)
            self.rz_2_3(q_device, wires=5)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=4, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 3 End

            #SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 4 Start
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            self.rz_3_0(q_device, wires=6)
            self.rz_3_1(q_device, wires=7)
            self.ry_3_0(q_device, wires=6)
            self.ry_3_1(q_device, wires=7)
            self.rz_3_2(q_device, wires=6)
            self.rz_3_3(q_device, wires=7)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 4 End

            #SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 1 Start
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            self.rz_4_0(q_device, wires=1)
            self.rz_4_1(q_device, wires=2)
            self.ry_4_0(q_device, wires=1)
            self.ry_4_1(q_device, wires=2)
            self.rz_4_2(q_device, wires=1)
            self.rz_4_3(q_device, wires=2)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 1 End

            #SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 2 Start
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            self.rz_5_0(q_device, wires=5)
            self.rz_5_1(q_device, wires=6)
            self.ry_5_0(q_device, wires=5)
            self.ry_5_1(q_device, wires=6)
            self.rz_5_2(q_device, wires=5)
            self.rz_5_3(q_device, wires=6)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 2 End
            #SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 3 Gate 1 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode)
            self.rz_6_0(q_device, wires=0)
            self.rz_6_1(q_device, wires=1)
            self.ry_6_0(q_device, wires=0)
            self.ry_6_1(q_device, wires=1)
            self.rz_6_2(q_device, wires=0)
            self.rz_6_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode, parent_graph=graph)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 3 Gate 1 End
    def __init__(self, input_size):
        super().__init__()
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        self.smx = nn.Softmax()
        self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

        self.layer_1 = nn.Linear(1, 64)
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(64, 1, kernel_size=2, stride=2)
        self.layer_6 = nn.Tanh()

        self.encoder=tq.GeneralEncoder(
            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "ry", "wires" : [0]},
             {"input_idx" : [9], "func" : "ry", "wires" : [1]},
             {"input_idx" : [10], "func" : "ry", "wires" : [2]},
             {"input_idx" : [11], "func" : "ry", "wires" : [3]},
             {"input_idx" : [12], "func" : "ry", "wires" : [4]},
             {"input_idx" : [13], "func" : "ry", "wires" : [5]},
             {"input_idx" : [14], "func" : "ry", "wires" : [6]},
             {"input_idx" : [15], "func" : "ry", "wires" : [7]}])

##    def get_angles_atan(self, in_x):
##        angles = torch.stack([torch.stack([torch.atan(item), torch.atan(item**2)]) for item in in_x])
##        return angles

    def forward(self, input_data, check=False):

        x_1 = self.layer_1(input_data)
        x_2 = self.layer_2(x_1)

        x_3 = self.layer_3(x_2)
        x_4 = self.layer_4(x_3)

        x_5 = self.layer_5(x_4)
        x_6 = self.layer_6(x_5)

        x_angles = torch.atan(x_6)
        for i in range(self.n_wires):

            tqf.hadamard(self.q_device, wires=i, static=self.static_mode, parent_graph=self.graph)
        self.encoder(self.q_device, x_angles)
        self.q_layer.forward(self.q_device, self.static_mode, self.graph)
        obs_1 = expval_joint_analytical(self.q_device, "ZZZZZZZZ")
        obs_2 = expval_joint_analytical(self.q_device, "ZZZYZZZZ")
        obs_3 = expval_joint_analytical(self.q_device, "ZZZZYZZZ")
        obs_4 = expval_joint_analytical(self.q_device, "ZZZYYZZZ")
        expectations = torch.stack([obs_1, obs_2, obs_3, obs_4], dim=1)
        measure_weights = self.smx(expectations)

        measure_weights = measure_weights.view(4)
        return measure_weights

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
        loss = loss + ((l - p) ** 2)
    loss = loss / len(labels)
    return loss

def epsilon_greedy(TreeTensor, epsilon, s, n_actions, check=False, train=False):
    if train or np.random.rand() < (1 - epsilon):
        with torch.no_grad():
            measurements = TreeTensor(s, check=check)
            action = torch.argmax(measurements)
            return action
        if check:
            print("Argmax result: " + str(action))
    else:
        choices = np.random.randint(0, n_actions, size=25)

        action = np.bincount(choices).argmax()

        action = torch.tensor(action)
        return action

def cost(model, features, labels, dev):
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    return loss_total

def ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, top_dev, opt, sched, render=True):

    act_range = [0, 1, 2, 6]
    logging.basicConfig(filename="ExperimentDebug1.txt", level=logging.DEBUG)
    logging.captureWarnings(True)

    scores = []
    target_update = 20
    batch_size = 5
    optimize_steps = 5
    target_update_counter = 0
    iter_index = []
    iter_reward = []
    iter_total_steps = []
    cost_list = []
    timestep_reward = []
    memory = ReplayMemory(80)
    seedVal = int(time.time())
    np.random.seed(seedVal)
    env = gym.make(env_name, render_mode="human")

    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
    start_state = None
    for episode in range(episodes):
        env_record.reset()
        print("Episode: " + str(episode))
        t = 0
        total_reward = 0

        done = False

        seedVal = int(time.time())
        np.random.seed(seedVal)

        if episode == 0:
            start_state = env_record.env.grid
        else:
            env_record.env.grid = start_state

        if render:
            env_record.render()

        observation = env_record.env.gen_obs()
        print(observation['image'].shape)
        observation = torch.tensor(observation['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
        print("Observation shape: " + str(observation.shape))
        observation.requires_grad = True

        act = epsilon_greedy(model, epsilon, observation, n_actions, check=False)
        action = act_range[act]
        scores.append(total_reward)
        while t < max_steps:
            print("Episode: " + str(episode) + " , " + "Timestep: " + str(t))
            if render:
                env_record.render()
            t += 1
            target_update_counter += 1
            seedVal = int(time.time())
            np.random.seed(seedVal)
            next_obs, reward, done, _, info = env_record.step(action)
            #print(next_obs['image'].shape)
            next_obs = torch.tensor(next_obs['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
            next_obs.requires_grad = True
            total_reward += reward
            act_ = epsilon_greedy(model, epsilon, next_obs, n_actions, check=True)

            action_ = act_range[act_]
            memory.push(observation, act, reward, next_obs, done)
            if len(memory) > batch_size:
                batch_sampled = memory.sample(batch_size)

                Qtarget = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state)) for item in batch_sampled]
                loss = cost(model, batch_sampled, Qtarget, top_dev)
                for param in model.parameters():
                    print("Parameter gradient: " + str(param.grad))
                opt.zero_grad()

                opt.step()
            current_replay_memory = memory.output_all()
            current_target_for_replay_memory = [item.reward + (1 - int(item.done)) * gamma * epsilon_greedy(model, epsilon, item.next_state, n_actions) for item in current_replay_memory]
            if target_update_counter >= target_update:
                target_update_counter = 0
            observation, action = next_obs, action_
            if done:
                epsilon = epsilon / ((episode / 1000) + 1)
                timestep_reward.append(total_reward)
                iter_index.append(episode)
                iter_total_steps.append(t)
                break
    torch.save(model.state_dict(), param_file1)
    torch.save(model.mps.state_dict, param_file2)
    return timestep_reward, iter_index, iter_reward, iter_total_steps

def main():
    env_name = "MiniGrid-Empty-8x8-v0"
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    alpha = 0.4
    gamma = 0.999
    epsilon = 1
    episodes = 1000
    max_steps = 100
    n_actions = 4
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    model = TreeTensorAgent(147).to(device)
    tn_opt = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(tn_opt, T_max=episodes)
    timestep_reward, iter_index, iter_reward, iter_total_steps = ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, device, tn_opt, scheduler)
    x_vals = np.arange(episodes)
    y_vals = np.asarray(timestep_reward)
    fig, ax = plt.subplots()
    ax.plot(x_vals, y_vals)
    ax.grid()
    ax.set(xlabel="Episode", ylabel="Total Score", title="Deep Quantum TTN Learning Training Process: 6-Site Random Lava Minigrid")
    fig.savefig("TTNTrain.png")
    plt.close(fig)

if __name__ == "__main__":
    main()
Hanrui-Wang commented 1 year ago

Hi DarthMalloc,

could you try adding loss.backward() between opt.zero_grad() and opt.step() ?

JustinS6626 commented 1 year ago

Thanks for getting back to me! Sorry about the type in my upload. loss.backward() is actually included in that spot in the original file, but somehow it did not get copied into the code that I posted on Monday.

JustinS6626 commented 1 year ago

The problem is that the parameters of the model have no gradient even after loss.backward() is called.

Hanrui-Wang commented 1 year ago

Hi DarthMalloc,

Please find the code below that can obtain the gradient correctly, I modified several parts which may not reflect your initial meaning of the model.

There are two major issues:

  1. In the cost function, creating a new tensor using torch.tensor will lose the gradient tracking. so I modified as below:
def cost(model, features, labels, dev):
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    # loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    loss_total = loss_func(torch.stack(labels), torch.stack(predictions))
    return loss_total
  1. for tq.QuantumDevice, you can either create a new qdev every time running the forward function; or you can call the qdev.reset_state(bsz=10) every time running the forward function.

Here is the full modified code:

import math
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn as nn
import torch.nn.functional as F
import time
import datetime
import calendar
import random
from minigrid.wrappers import *
import logging
import matplotlib.pyplot as plt
import torchquantum as tq
import torchquantum.functional as tqf
from torchquantum.measurement import *

import pickle as pkl
import gymnasium as gym
from gymnasium.wrappers.record_video import RecordVideo
from collections import namedtuple, deque
#from gymnasium.wrappers.record_episode_statistics import RecordEpisodeStatistic
import numpy as np
import os
from gymnasium.envs.registration import *

import os
os.environ["SDL_VIDEODRIVER"] = "dummy"

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def output_all(self):
        return self.memory

    def __len__(self):
        return len(self.memory)

Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))

class TreeTensorAgent(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 8
            self.n_actions = 4
##                self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

            #self.bias = torch.tensor(np.random.rand(4), requires_grad=True)
            self.rz_0_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_0_0 = tq.RY(has_params=True, trainable=True)
            self.ry_0_1 = tq.RY(has_params=True, trainable=True)
            self.rz_0_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_0_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_1_0 = tq.RY(has_params=True, trainable=True)
            self.ry_1_1 = tq.RY(has_params=True, trainable=True)
            self.rz_1_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_1_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_2_0 = tq.RY(has_params=True, trainable=True)
            self.ry_2_1 = tq.RY(has_params=True, trainable=True)
            self.rz_2_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_2_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_3_0 = tq.RY(has_params=True, trainable=True)
            self.ry_3_1 = tq.RY(has_params=True, trainable=True)
            self.rz_3_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_3_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_4_0 = tq.RY(has_params=True, trainable=True)
            self.ry_4_1 = tq.RY(has_params=True, trainable=True)
            self.rz_4_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_4_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_5_0 = tq.RY(has_params=True, trainable=True)
            self.ry_5_1 = tq.RY(has_params=True, trainable=True)
            self.rz_5_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_5_3 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_0 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_1 = tq.RZ(has_params=True, trainable=True)
            self.ry_6_0 = tq.RY(has_params=True, trainable=True)
            self.ry_6_1 = tq.RY(has_params=True, trainable=True)
            self.rz_6_2 = tq.RZ(has_params=True, trainable=True)
            self.rz_6_3 = tq.RZ(has_params=True, trainable=True)
            self.cnot = tq.CNOT(has_params=False, trainable=False)

        def forward(self, q_device, static_mode, graph):
            # self.q_device = q_device
            #SO4(self.q_device, [self.ry_0_0, self.ry_0_1], [self.rz_0_0, self.rz_0_1, self.rz_0_2, self.rz_0_3], self.cnot, [0, 1], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 1 Start
            tqf.rz(q_device, wires=0, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            self.rz_0_0(q_device, wires=0)
            self.rz_0_1(q_device, wires=1)
            self.ry_0_0(q_device, wires=0)
            self.ry_0_1(q_device, wires=1)
            self.rz_0_2(q_device, wires=0)
            self.rz_0_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[1, 0], static=static_mode)
            tqf.ry(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=0, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 1 End

            #SO4(self.q_device, [self.ry_1_0, self.ry_1_1], [self.rz_1_0, self.rz_1_1, self.rz_1_2, self.rz_1_3], self.cnot, [2, 3], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 2 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            self.rz_1_0(q_device, wires=2)
            self.rz_1_1(q_device, wires=3)
            self.ry_1_0(q_device, wires=2)
            self.ry_1_1(q_device, wires=3)
            self.rz_1_2(q_device, wires=2)
            self.rz_1_3(q_device, wires=3)
            tqf.cnot(q_device, wires=[3, 2], static=static_mode)
            tqf.ry(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=3, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 2 End

            #SO4(self.q_device, [self.ry_2_0, self.ry_2_1], [self.rz_2_0, self.rz_2_1, self.rz_2_2, self.rz_2_3], self.cnot, [4, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 3 Start
            tqf.rz(q_device, wires=4, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            self.rz_2_0(q_device, wires=4)
            self.rz_2_1(q_device, wires=5)
            self.ry_2_0(q_device, wires=4)
            self.ry_2_1(q_device, wires=5)
            self.rz_2_2(q_device, wires=4)
            self.rz_2_3(q_device, wires=5)
            tqf.cnot(q_device, wires=[5, 4], static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=4, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 3 End

            #SO4(self.q_device, [self.ry_3_0, self.ry_3_1], [self.rz_3_0, self.rz_3_1, self.rz_3_2, self.rz_3_3], self.cnot, [6, 7], static=static_mode_mode, parent_graph=graph)
            #Layer 1 Gate 4 Start
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            self.rz_3_0(q_device, wires=6)
            self.rz_3_1(q_device, wires=7)
            self.ry_3_0(q_device, wires=6)
            self.ry_3_1(q_device, wires=7)
            self.rz_3_2(q_device, wires=6)
            self.rz_3_3(q_device, wires=7)
            tqf.cnot(q_device, wires=[7, 6], static=static_mode)
            tqf.ry(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=7, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 1 Gate 4 End

            #SO4(self.q_device, [self.ry_4_0, self.ry_4_1], [self.rz_4_0, self.rz_4_1, self.rz_4_2, self.rz_4_3], self.cnot, [1, 2], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 1 Start
            tqf.rz(q_device, wires=1, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            self.rz_4_0(q_device, wires=1)
            self.rz_4_1(q_device, wires=2)
            self.ry_4_0(q_device, wires=1)
            self.ry_4_1(q_device, wires=2)
            self.rz_4_2(q_device, wires=1)
            self.rz_4_3(q_device, wires=2)
            tqf.cnot(q_device, wires=[2, 1], static=static_mode)
            tqf.ry(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=1, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 1 End

            #SO4(self.q_device, [self.ry_5_0, self.ry_5_1], [self.rz_5_0, self.rz_5_1, self.rz_5_2, self.rz_5_3], self.cnot, [5, 6], static=static_mode_mode, parent_graph=graph)
            #Layer 2 Gate 2 Start
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            self.rz_5_0(q_device, wires=5)
            self.rz_5_1(q_device, wires=6)
            self.ry_5_0(q_device, wires=5)
            self.ry_5_1(q_device, wires=6)
            self.rz_5_2(q_device, wires=5)
            self.rz_5_3(q_device, wires=6)
            tqf.cnot(q_device, wires=[6, 5], static=static_mode)
            tqf.ry(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=6, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 2 Gate 2 End
            #SO4(self.q_device, [self.ry_6_0, self.ry_6_1], [self.rz_6_0, self.rz_6_1, self.rz_6_2, self.rz_6_3], self.cnot, [2, 5], static=static_mode_mode, parent_graph=graph)
            #Layer 3 Gate 1 Start
            tqf.rz(q_device, wires=2, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.ry(q_device, wires=5, params=torch.tensor([np.pi / 2]), static=static_mode)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode)
            self.rz_6_0(q_device, wires=0)
            self.rz_6_1(q_device, wires=1)
            self.ry_6_0(q_device, wires=0)
            self.ry_6_1(q_device, wires=1)
            self.rz_6_2(q_device, wires=0)
            self.rz_6_3(q_device, wires=1)
            tqf.cnot(q_device, wires=[5, 2], static=static_mode, parent_graph=graph)
            tqf.ry(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=2, params=torch.tensor([-np.pi / 2]), static=static_mode)
            tqf.rz(q_device, wires=5, params=torch.tensor([-np.pi / 2]), static=static_mode)
            #Layer 3 Gate 1 End

    def __init__(self, input_size):
        super().__init__()
        self.n_wires = 8
        self.n_actions = 4
        self.input_size = input_size
        self.q_layer = self.QLayer()
        self.smx = nn.Softmax()
        # self.q_device = tq.QuantumDevice(n_wires=self.n_wires)

        self.layer_1 = nn.Linear(1, 64)
        self.layer_2 = nn.ReLU()
        self.layer_3 = nn.Conv1d(self.input_size, 64, kernel_size=2, stride=2)
        self.layer_4 = nn.ReLU()
        self.layer_5 = nn.Conv1d(64, 1, kernel_size=2, stride=2)
        self.layer_6 = nn.Tanh()

        self.encoder=tq.GeneralEncoder(
            [{"input_idx" : [0], "func" : "ry", "wires" : [0]},
             {"input_idx" : [1], "func" : "ry", "wires" : [1]},
             {"input_idx" : [2], "func" : "ry", "wires" : [2]},
             {"input_idx" : [3], "func" : "ry", "wires" : [3]},
             {"input_idx" : [4], "func" : "ry", "wires" : [4]},
             {"input_idx" : [5], "func" : "ry", "wires" : [5]},
             {"input_idx" : [6], "func" : "ry", "wires" : [6]},
             {"input_idx" : [7], "func" : "ry", "wires" : [7]},
             {"input_idx" : [8], "func" : "ry", "wires" : [0]},
             {"input_idx" : [9], "func" : "ry", "wires" : [1]},
             {"input_idx" : [10], "func" : "ry", "wires" : [2]},
             {"input_idx" : [11], "func" : "ry", "wires" : [3]},
             {"input_idx" : [12], "func" : "ry", "wires" : [4]},
             {"input_idx" : [13], "func" : "ry", "wires" : [5]},
             {"input_idx" : [14], "func" : "ry", "wires" : [6]},
             {"input_idx" : [15], "func" : "ry", "wires" : [7]}])

##    def get_angles_atan(self, in_x):
##        angles = torch.stack([torch.stack([torch.atan(item), torch.atan(item**2)]) for item in in_x])
##        return angles

    def forward(self, input_data, check=False):
        qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=input_data.shape[0], device=input_data.device)

        x_1 = self.layer_1(input_data)
        x_2 = self.layer_2(x_1)

        x_3 = self.layer_3(x_2)
        x_4 = self.layer_4(x_3)

        x_5 = self.layer_5(x_4)
        x_6 = self.layer_6(x_5)

        x_angles = torch.atan(x_6)
        for i in range(self.n_wires):
            tqf.hadamard(qdev, wires=i, static=self.static_mode, parent_graph=self.graph)
        self.encoder(qdev, x_angles)
        self.q_layer(qdev, self.static_mode, self.graph)
        obs_1 = expval_joint_analytical(qdev, "ZZZZZZZZ")
        obs_2 = expval_joint_analytical(qdev, "ZZZYZZZZ")
        obs_3 = expval_joint_analytical(qdev, "ZZZZYZZZ")
        obs_4 = expval_joint_analytical(qdev, "ZZZYYZZZ")
        expectations = torch.stack([obs_1, obs_2, obs_3, obs_4], dim=1)
        measure_weights = self.smx(expectations)

        measure_weights = measure_weights.mean(-1)
        return measure_weights

def square_loss(labels, predictions):
    loss = 0
    for l, p in zip(labels, predictions):
        loss = loss + ((l - p) ** 2)
    loss = loss / len(labels)
    return loss

def epsilon_greedy(TreeTensor, epsilon, s, n_actions, check=False, train=False):
    if train or np.random.rand() < (1 - epsilon):
        with torch.no_grad():
            measurements = TreeTensor(s, check=check)
            action = torch.argmax(measurements)
            return action
        if check:
            print("Argmax result: " + str(action))
    else:
        choices = np.random.randint(0, n_actions, size=25)

        action = np.bincount(choices).argmax()

        action = torch.tensor(action)
        return action

def cost(model, features, labels, dev):
    loss_func = nn.SmoothL1Loss()
    predictions = [model(item.state)[item.action] for item in features]
    # loss_total = loss_func(torch.tensor(labels, requires_grad=True, device=dev), torch.tensor(predictions, requires_grad=True, device=dev))
    loss_total = loss_func(torch.stack(labels), torch.stack(predictions))
    return loss_total

def ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, top_dev, opt, sched, render=True):

    act_range = [0, 1, 2, 6]
    logging.basicConfig(filename="ExperimentDebug1.txt", level=logging.DEBUG)
    logging.captureWarnings(True)

    scores = []
    target_update = 20
    batch_size = 5
    optimize_steps = 5
    target_update_counter = 0
    iter_index = []
    iter_reward = []
    iter_total_steps = []
    cost_list = []
    timestep_reward = []
    memory = ReplayMemory(80)
    seedVal = int(time.time())
    np.random.seed(seedVal)
    env = gym.make(env_name, render_mode="human")

    env_record = RecordVideo(env, f"video/TTNMinigridTraining")
    start_state = None
    for episode in range(episodes):
        env_record.reset()
        print("Episode: " + str(episode))
        t = 0
        total_reward = 0

        done = False

        seedVal = int(time.time())
        np.random.seed(seedVal)

        if episode == 0:
            start_state = env_record.env.grid
        else:
            env_record.env.grid = start_state

        if render:
            env_record.render()

        observation = env_record.env.gen_obs()
        print(observation['image'].shape)
        observation = torch.tensor(observation['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
        print("Observation shape: " + str(observation.shape))
        observation.requires_grad = True

        act = epsilon_greedy(model, epsilon, observation, n_actions, check=False)
        action = act_range[act]
        scores.append(total_reward)
        while t < max_steps:
            print("Episode: " + str(episode) + " , " + "Timestep: " + str(t))
            if render:
                env_record.render()
            t += 1
            target_update_counter += 1
            seedVal = int(time.time())
            np.random.seed(seedVal)
            next_obs, reward, done, _, info = env_record.step(action)
            #print(next_obs['image'].shape)
            next_obs = torch.tensor(next_obs['image']).type('torch.FloatTensor').view(147, 1).to(top_dev)
            next_obs.requires_grad = True
            total_reward += reward
            act_ = epsilon_greedy(model, epsilon, next_obs, n_actions, check=True)

            action_ = act_range[act_]
            memory.push(observation, act, reward, next_obs, done)
            if len(memory) > batch_size:
                batch_sampled = memory.sample(batch_size)
                Qtarget = []

                Qtarget = [item.reward + (1 - int(item.done)) * gamma * torch.max(model(item.next_state).flatten()) for item in batch_sampled]
                loss = cost(model, batch_sampled, Qtarget, top_dev)
                # for param in model.parameters():
                    # print("Parameter gradient: " + str(param.grad))
                opt.zero_grad()
                loss.backward()
                for name, param in model.named_parameters():
                    print(f"Parameter {name} gradient: " + str(param.grad))

                opt.step()
            current_replay_memory = memory.output_all()
            current_target_for_replay_memory = [item.reward + (1 - int(item.done)) * gamma * epsilon_greedy(model, epsilon, item.next_state, n_actions) for item in current_replay_memory]
            if target_update_counter >= target_update:
                target_update_counter = 0
            observation, action = next_obs, action_
            if done:
                epsilon = epsilon / ((episode / 1000) + 1)
                timestep_reward.append(total_reward)
                iter_index.append(episode)
                iter_total_steps.append(t)
                break
    torch.save(model.state_dict(), param_file1)
    torch.save(model.mps.state_dict, param_file2)
    return timestep_reward, iter_index, iter_reward, iter_total_steps

def main():
    env_name = "MiniGrid-Empty-8x8-v0"
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    alpha = 0.4
    gamma = 0.999
    epsilon = 1
    episodes = 1000
    max_steps = 100
    n_actions = 4
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    model = TreeTensorAgent(147).to(device)
    tn_opt = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(tn_opt, T_max=episodes)
    timestep_reward, iter_index, iter_reward, iter_total_steps = ttn_train(env_name, model, alpha, gamma, epsilon, episodes, max_steps, n_actions, device, tn_opt, scheduler)
    x_vals = np.arange(episodes)
    y_vals = np.asarray(timestep_reward)
    fig, ax = plt.subplots()
    ax.plot(x_vals, y_vals)
    ax.grid()
    ax.set(xlabel="Episode", ylabel="Total Score", title="Deep Quantum TTN Learning Training Process: 6-Site Random Lava Minigrid")
    fig.savefig("TTNTrain.png")
    plt.close(fig)

if __name__ == "__main__":
    main()
JustinS6626 commented 1 year ago

Sorry about the delay. Thank you so much for your help! The problem seems to be resolved now.