hidet-org / hidet

An open-source efficient deep learning framework/compiler, written in Python.
https://hidet.org
Apache License 2.0

Question about Complex datatype support #265

Closed: wcqc closed this issue 1 year ago

wcqc commented 1 year ago

The inductor backend of PyTorch 2.0 does not officially support complex data types yet (https://github.com/pytorch/pytorch/issues/93424). I'm wondering whether hidet currently has the same limitation.

If it does, does it rely on other parts of PyTorch 2.0 (e.g., inductor, dynamo, etc.) to fully support complex types, or can complex support be added separately?
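Whether complex support can be layered on separately depends partly on whether complex operators can be lowered to real ones. The sketch below illustrates a general technique, not necessarily how hidet or TorchQuantum implement it. A complex matrix product can be computed with four real matrix products, using (A_re + i*A_im)(B_re + i*B_im) = (A_re B_re - A_im B_im) + i(A_re B_im + A_im B_re). The helper names are hypothetical, and plain Python lists stand in for tensors:

```python
def real_matmul(a, b):
    """Plain matrix product of two lists of lists (works for real or complex entries)."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][p] * b[p][j] for p in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]


def split_complex(m):
    """Split a complex matrix into (real part, imaginary part) real matrices."""
    return [[z.real for z in row] for row in m], [[z.imag for z in row] for row in m]


def complex_matmul_via_real(a_re, a_im, b_re, b_im):
    """Compute (A_re + i*A_im) @ (B_re + i*B_im) using four real matmuls.

    Real part:      A_re @ B_re - A_im @ B_im
    Imaginary part: A_re @ B_im + A_im @ B_re
    """
    re_re = real_matmul(a_re, b_re)
    im_im = real_matmul(a_im, b_im)
    re_im = real_matmul(a_re, b_im)
    im_re = real_matmul(a_im, b_re)
    out_re = [[x - y for x, y in zip(r1, r2)] for r1, r2 in zip(re_re, im_im)]
    out_im = [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(re_im, im_re)]
    return out_re, out_im


# Usage: check against Python's built-in complex arithmetic.
A = [[1 + 2j, 3 - 1j], [1j, 2 + 0j]]
B = [[2 - 1j, 1 + 1j], [1 + 0j, -2j]]
out_re, out_im = complex_matmul_via_real(*split_complex(A), *split_complex(B))
reference = real_matmul(A, B)
print(out_re, out_im)
```

After this split, a compiler only needs real-valued kernels. The cost is four real matmuls per complex matmul, plus extra memory traffic for the separate real and imaginary buffers.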

yaoyaoding commented 1 year ago

Hi @wcqc,

Yes, we do not officially support complex data types yet.

Hidet does not rely on any PyTorch components to work (e.g., you can install hidet and use its other frontends, such as ONNX).

Could you tell me which networks you are using? It would be great if you could share a self-contained script that runs a complex-valued network; I am happy to add some basic support for that in hidet.

wcqc commented 1 year ago

Hi @yaoyaoding,

Thank you for the swift reply! torch.compile() failed for me with the inductor backend of PyTorch 2.0 because complex dtypes are not supported in inductor and in many other places in PyTorch 2.0. My question above was a general one about how hidet handles this.

More specifically, I'm trying to accelerate code based on this library: https://github.com/mit-han-lab/torchquantum, which mainly involves multiplying matrices that contain complex numbers. An example program is here: https://github.com/mit-han-lab/torchquantum/blob/main/examples/simple_mnist/mnist_example.py. When I run it with torch.compile(model, backend='hidet'), I get this error:

/venvs/pt2py3.10/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: hidet_backend raised TypeError: avg_pool2d() missing 2 required positional arguments: 'stride' and 'padding', occurred when calling avg_pool2d with 
    args: (<hidet.Tensor object at 0x7ff987e95840>, 6)
  kwargs: {}
avg_pool2d is defined at
  File "/venvs/pt2py3.10/lib/python3.10/site-packages/hidet/graph/frontend/torch/register_functions.py", line 175

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
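The TypeError says that the hidet `avg_pool2d` registered at that time required `stride` and `padding` explicitly. The script, however, calls `F.avg_pool2d(x, 6)` with only a kernel size. In PyTorch, an omitted stride defaults to `kernel_size`, and padding defaults to 0.

Below is a minimal sketch of that default-filling. It uses a hypothetical helper, not hidet's actual code. It also applies the standard pooled-size formula (with `ceil_mode=False`) to check the result. Assuming 28x28 MNIST inputs, the pooled map comes out 4x4 = 16 values, which is what `.view(bsz, 16)` expects:

```python
def normalize_avg_pool2d_args(kernel_size, stride=None, padding=0):
    """Hypothetical helper: fill in the defaults torch.nn.functional.avg_pool2d uses.

    An omitted stride falls back to kernel_size, padding defaults to 0,
    and single ints are expanded to (height, width) pairs.
    """
    def pair(value):
        return (value, value) if isinstance(value, int) else tuple(value)

    kernel = pair(kernel_size)
    stride = kernel if stride is None else pair(stride)
    return kernel, stride, pair(padding)


def pooled_size(in_size, kernel, stride, padding):
    """Output length along one axis for pooling with ceil_mode=False."""
    return (in_size + 2 * padding - kernel) // stride + 1


# F.avg_pool2d(x, 6) from the script: only kernel_size is given.
kernel, stride, padding = normalize_avg_pool2d_args(6)
h = pooled_size(28, kernel[0], stride[0], padding[0])  # assumes 28x28 input
w = pooled_size(28, kernel[1], stride[1], padding[1])
print(kernel, stride, padding, h * w)
```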

Now my question is: does it look like hidet would work with a library such as TorchQuantum and provide acceleration (after fixing the above error)? Or, as it currently stands, would hidet not work with TorchQuantum because of the lack of complex dtype support?

Thanks again.

yaoyaoding commented 1 year ago

Hi @wcqc,

Sorry, I still cannot reproduce the error you encountered: I do not have a background in quantum machine learning/simulation, and I do not know where to put torch.compile in the example.

Again, a self-contained example that can be directly run would be helpful.

In general, if torch dynamo can extract the computation graph for us, we are happy to add the support. I expect it would not take long to make it work functionally, as long as dynamo can extract the graph and no unusual operators show up. But to achieve good performance, we would still need to add schedule templates specific to complex numbers. A complex64 or complex128 value takes 8 or 16 bytes, which is larger than the heavily optimized data types like float32 and float16 (4 and 2 bytes).
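The byte sizes follow from how complex numbers are stored. Each value is a pair of reals, so complex64 is two float32 values and complex128 is two float64 values. A minimal stdlib sketch using `struct` is below. The 64x64 block size is only an illustrative assumption, not a hidet tiling parameter:

```python
import struct

# struct format codes: 'e' is IEEE 754 half precision, 'f' is float32, 'd' is float64.
# A complex number is stored as a (real, imag) pair, so complex64 is two float32s
# and complex128 is two float64s. '<' selects standard sizes with no padding.
FORMATS = {
    "float16": "<e",
    "float32": "<f",
    "complex64": "<2f",
    "complex128": "<2d",
}


def element_bytes(dtype_name):
    """Size in bytes of one element of the given dtype."""
    return struct.calcsize(FORMATS[dtype_name])


# Pack one complex64 value the way it sits in memory: real part, then imaginary part.
z = 0.5 - 2.0j
packed = struct.pack(FORMATS["complex64"], z.real, z.imag)

# Bytes needed for a 64x64 block of each dtype (an arbitrary illustrative size).
for name in FORMATS:
    size = element_bytes(name)
    print(f"{name:>10}: {size:2d} B/element, {64 * 64 * size:6d} B per 64x64 block")
```

With the same on-chip memory budget, a block holds 4x fewer complex64 elements than float16 ones. That is one reason schedules tuned for float16/float32 may not carry over unchanged.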

wcqc commented 1 year ago

Here is a script that reproduces the above error. It can be run with python script.py, and it still requires the torchquantum library.

Is this sufficient?

import hidet
import torch
import torch.nn.functional as F
import torch.optim as optim
import argparse
import random
import numpy as np

import torchquantum as tq
from torchquantum.plugins import (
    tq2qiskit_measurement,
    qiskit_assemble_circs,
    op_history2qiskit,
    op_history2qiskit_expand_params,
)

from torchquantum.datasets import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR

class QFCModel(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 4
            self.random_layer = tq.RandomLayer(
                n_ops=50, wires=list(range(self.n_wires))
            )

            # gates with trainable parameters
            self.rx0 = tq.RX(has_params=True, trainable=True)
            self.ry0 = tq.RY(has_params=True, trainable=True)
            self.rz0 = tq.RZ(has_params=True, trainable=True)
            self.crx0 = tq.CRX(has_params=True, trainable=True)

        def forward(self, qdev: tq.QuantumDevice):
            self.random_layer(qdev)

            # some trainable gates (instantiated ahead of time)
            self.rx0(qdev, wires=0)
            self.ry0(qdev, wires=1)
            self.rz0(qdev, wires=3)
            self.crx0(qdev, wires=[0, 2])

            # add some more non-parameterized gates (add on-the-fly)
            qdev.h(wires=3)  # type: ignore
            qdev.sx(wires=2)  # type: ignore
            qdev.cnot(wires=[3, 0])  # type: ignore
            qdev.rx(
                wires=1,
                params=torch.tensor([0.1]),
                static=self.static_mode,
                parent_graph=self.graph,
            )  # type: ignore

    def __init__(self):
        super().__init__()
        self.n_wires = 4
        self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3rx"])

        self.q_layer = self.QLayer()
        self.measure = tq.MeasureAll(tq.PauliZ)

    def forward(self, x, use_qiskit=False):
        qdev = tq.QuantumDevice(
            n_wires=self.n_wires, bsz=x.shape[0], device=x.device, record_op=True
        )

        bsz = x.shape[0]
        x = F.avg_pool2d(x, 6).view(bsz, 16)
        devi = x.device

        if use_qiskit:
            # use qiskit to process the circuit
            # create the qiskit circuit for encoder
            self.encoder(qdev, x)
            op_history_parameterized = qdev.op_history
            qdev.reset_op_history()
            encoder_circs = op_history2qiskit_expand_params(self.n_wires, op_history_parameterized, bsz=bsz)

            # create the qiskit circuit for trainable quantum layers
            self.q_layer(qdev)
            op_history_fixed = qdev.op_history
            qdev.reset_op_history()
            q_layer_circ = op_history2qiskit(self.n_wires, op_history_fixed)

            # create the qiskit circuit for measurement
            measurement_circ = tq2qiskit_measurement(qdev, self.measure)

            # assemble the encoder, trainable quantum layers, and measurement circuits
            assembled_circs = qiskit_assemble_circs(
                encoder_circs, q_layer_circ, measurement_circ
            )

            # call the qiskit processor to process the circuit
            x0 = self.qiskit_processor.process_ready_circs(qdev, assembled_circs).to(  # type: ignore
                devi
            )
            x = x0

        else:
            # use torchquantum to process the circuit
            self.encoder(qdev, x)
            qdev.reset_op_history()
            self.q_layer(qdev)
            x = self.measure(qdev)

        x = x.reshape(bsz, 2, 2).sum(-1).squeeze()
        x = F.log_softmax(x, dim=1)

        return x

def train(dataflow, model, device, optimizer):
    for feed_dict in dataflow["train"]:
        inputs = feed_dict["image"].to(device)
        targets = feed_dict["digit"].to(device)

        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}", end="\r")

def valid_test(dataflow, split, model, device, qiskit=False):
    target_all = []
    output_all = []
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["image"].to(device)
            targets = feed_dict["digit"].to(device)

            outputs = model(inputs, use_qiskit=qiskit)

            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all, dim=0)
        output_all = torch.cat(output_all, dim=0)

    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = F.nll_loss(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--static", action="store_true", help="compute with " "static mode"
    )
    parser.add_argument("--pdb", action="store_true", help="debug with pdb")
    parser.add_argument(
        "--wires-per-block", type=int, default=2, help="wires per block in static mode"
    )
    parser.add_argument(
        "--epochs", type=int, default=2, help="number of training epochs"
    )

    args = parser.parse_args()

    if args.pdb:
        import pdb

        pdb.set_trace()

    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset = MNIST(
        root="./mnist_data",
        train_valid_split_ratio=[0.9, 0.1],
        digits_of_interest=[3, 6],
        n_test_samples=75,
    )
    dataflow = dict()

    for split in dataset:
        sampler = torch.utils.data.RandomSampler(dataset[split])
        dataflow[split] = torch.utils.data.DataLoader(
            dataset[split],
            batch_size=256,
            sampler=sampler,
            num_workers=8,
            pin_memory=True,
        )

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    _model = QFCModel().to(device)
    model = torch.compile(_model, backend='hidet')

    n_epochs = args.epochs
    optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    if args.static:
        # optionally switch to static mode, which can speed up
        # training
        model.q_layer.static_on(wires_per_block=args.wires_per_block)

    for epoch in range(1, n_epochs + 1):
        # train
        print(f"Epoch {epoch}:")
        train(dataflow, model, device, optimizer)
        print(optimizer.param_groups[0]["lr"])

        # valid
        valid_test(dataflow, "valid", model, device)
        scheduler.step()

    # test
    valid_test(dataflow, "test", model, device, qiskit=False)

    # run on Qiskit simulator and real Quantum Computers
    try:
        from qiskit import IBMQ
        from torchquantum.plugins import QiskitProcessor

        # first, run on the Qiskit simulator
        print(f"\nTest with Qiskit Simulator")
        processor_simulation = QiskitProcessor(use_real_qc=False)
        model.set_qiskit_processor(processor_simulation)
        valid_test(dataflow, "test", model, device, qiskit=True)

        # then try to run on REAL QC
        backend_name = "ibmq_lima"
        print(f"\nTest on Real Quantum Computer {backend_name}")
        # Please specify your own hub group and project if you have the
        # IBMQ premium plan to access more machines.
        processor_real_qc = QiskitProcessor(
            use_real_qc=True,
            backend_name=backend_name,
            hub="ibm-q",
            group="open",
            project="main",
        )
        model.set_qiskit_processor(processor_real_qc)
        valid_test(dataflow, "test", model, device, qiskit=True)
    except ImportError:
        print(
            "Please install qiskit, create an IBM Q Experience Account and "
            "save the account token according to the instruction at "
            "'https://github.com/Qiskit/qiskit-ibmq-provider', "
            "then try again."
        )

if __name__ == "__main__":
    main()

yaoyaoding commented 1 year ago

Thanks for the script!

I will take a look when I have time and add any missing operators as needed.

yaoyaoding commented 1 year ago

Hi @wcqc,

I have added the missing operators in #271. But there is a bug that I cannot fix from the hidet side. Running the script below produces the error message that follows it. It is likely a bug in torch dynamo: I have checked the outputs of the hidet backend for the sub-graphs it received, and everything looks good. I did observe that there are a lot of fusion opportunities in this network (e.g., in outs/graphs/graph_3, the whole sub-graph could be fused into a single kernel).
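Below is a toy model of why fusion helps. It is not hidet's fusion pass, and plain Python lists stand in for GPU buffers. When three elementwise ops run as separate kernels, each one reads and writes a full buffer. The fused version reads the input once and writes the output once. The function names and the "2 accesses per element per kernel" traffic model are assumptions made for illustration:

```python
def unfused(x, scale, bias):
    """Three separate elementwise 'kernels', each materializing a full buffer."""
    t1 = [v * scale for v in x]          # kernel 1: read x,  write t1
    t2 = [v + bias for v in t1]          # kernel 2: read t1, write t2
    return [max(v, 0.0) for v in t2]     # kernel 3: read t2, write out


def fused(x, scale, bias):
    """One 'kernel': read x once, write the result once, no intermediates."""
    return [max(v * scale + bias, 0.0) for v in x]


def elements_moved(n, num_kernels):
    """Toy traffic model: each elementwise kernel reads n and writes n elements."""
    return 2 * n * num_kernels


x = [-1.5, 0.0, 0.25, 3.0]
assert unfused(x, 2.0, 0.5) == fused(x, 2.0, 0.5)  # same math either way

n = 1_000_000
print("unfused:", elements_moved(n, 3), "elements moved")
print("fused:  ", elements_moved(n, 1), "elements moved")
```

Under this model, the unfused chain moves 3x as many elements as the fused one. Fusing a whole sub-graph into one kernel extends the same idea to longer chains.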

You can also try hidet's ONNX frontend if you can export the inference task as an ONNX model.

import hidet
import torch
import torch.nn.functional as F
import torch.optim as optim
import argparse
import random
import numpy as np

import torchquantum as tq
from torchquantum.plugins import (
    tq2qiskit_measurement,
    qiskit_assemble_circs,
    op_history2qiskit,
    op_history2qiskit_expand_params,
)

from torchquantum.datasets import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR

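# Debugging options: keep hidet's compilation cache under ./outs/cache,
hidet.option.cache_dir('./outs/cache')
# dump the graph IR of each sub-graph hidet receives to ./outs/graphs (e.g., graph_3),
hidet.torch.dynamo_config.dump_graph_ir('./outs/graphs')
# and print each input graph that dynamo passes to the hidet backend.
hidet.torch.dynamo_config.print_input_graph()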

class QFCModel(tq.QuantumModule):
    class QLayer(tq.QuantumModule):
        def __init__(self):
            super().__init__()
            self.n_wires = 4
            self.random_layer = tq.RandomLayer(
                n_ops=50, wires=list(range(self.n_wires))
            )

            # gates with trainable parameters
            self.rx0 = tq.RX(has_params=True, trainable=True)
            self.ry0 = tq.RY(has_params=True, trainable=True)
            self.rz0 = tq.RZ(has_params=True, trainable=True)
            self.crx0 = tq.CRX(has_params=True, trainable=True)

        def forward(self, qdev: tq.QuantumDevice):
            self.random_layer(qdev)

            # some trainable gates (instantiated ahead of time)
            self.rx0(qdev, wires=0)
            self.ry0(qdev, wires=1)
            self.rz0(qdev, wires=3)
            self.crx0(qdev, wires=[0, 2])

            # add some more non-parameterized gates (add on-the-fly)
            qdev.h(wires=3)  # type: ignore
            qdev.sx(wires=2)  # type: ignore
            qdev.cnot(wires=[3, 0])  # type: ignore
            qdev.rx(
                wires=1,
                params=torch.tensor([0.1]),
                static=self.static_mode,
                parent_graph=self.graph,
            )  # type: ignore

    def __init__(self):
        super().__init__()
        self.n_wires = 4
        self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3rx"])

        self.q_layer = self.QLayer()
        self.measure = tq.MeasureAll(tq.PauliZ)

    def forward(self, x, use_qiskit=False):
        qdev = tq.QuantumDevice(
            n_wires=self.n_wires, bsz=x.shape[0], device=x.device, record_op=True
        )

        bsz = x.shape[0]
        x = F.avg_pool2d(x, 6).view(bsz, 16)
        devi = x.device

        if use_qiskit:
            # use qiskit to process the circuit
            # create the qiskit circuit for encoder
            self.encoder(qdev, x)
            op_history_parameterized = qdev.op_history
            qdev.reset_op_history()
            encoder_circs = op_history2qiskit_expand_params(self.n_wires, op_history_parameterized, bsz=bsz)

            # create the qiskit circuit for trainable quantum layers
            self.q_layer(qdev)
            op_history_fixed = qdev.op_history
            qdev.reset_op_history()
            q_layer_circ = op_history2qiskit(self.n_wires, op_history_fixed)

            # create the qiskit circuit for measurement
            measurement_circ = tq2qiskit_measurement(qdev, self.measure)

            # assemble the encoder, trainable quantum layers, and measurement circuits
            assembled_circs = qiskit_assemble_circs(
                encoder_circs, q_layer_circ, measurement_circ
            )

            # call the qiskit processor to process the circuit
            x0 = self.qiskit_processor.process_ready_circs(qdev, assembled_circs).to(  # type: ignore
                devi
            )
            x = x0

        else:
            # use torchquantum to process the circuit
            self.encoder(qdev, x)
            qdev.reset_op_history()
            self.q_layer(qdev)
            x = self.measure(qdev)

        x = x.reshape(bsz, 2, 2).sum(-1).squeeze()
        x = F.log_softmax(x, dim=1)

        return x

def train(dataflow, model, device, optimizer):
    for feed_dict in dataflow["train"]:
        inputs = feed_dict["image"].to(device)
        targets = feed_dict["digit"].to(device)

        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item()}", end="\r")

def valid_test(dataflow, split, model, device, qiskit=False):
    target_all = []
    output_all = []
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["image"].to(device)
            targets = feed_dict["digit"].to(device)

            outputs = model(inputs, use_qiskit=qiskit)

            target_all.append(targets)
            output_all.append(outputs)
        target_all = torch.cat(target_all, dim=0)
        output_all = torch.cat(output_all, dim=0)

    _, indices = output_all.topk(1, dim=1)
    masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
    size = target_all.shape[0]
    corrects = masks.sum().item()
    accuracy = corrects / size
    loss = F.nll_loss(output_all, target_all).item()

    print(f"{split} set accuracy: {accuracy}")
    print(f"{split} set loss: {loss}")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--static", action="store_true", help="compute with " "static mode"
    )
    parser.add_argument("--pdb", action="store_true", help="debug with pdb")
    parser.add_argument(
        "--wires-per-block", type=int, default=2, help="wires per block in static mode"
    )
    parser.add_argument(
        "--epochs", type=int, default=2, help="number of training epochs"
    )

    args = parser.parse_args()

    if args.pdb:
        import pdb

        pdb.set_trace()

    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset = MNIST(
        root="./mnist_data",
        train_valid_split_ratio=[0.9, 0.1],
        digits_of_interest=[3, 6],
        n_test_samples=75,
    )
    dataflow = dict()

    for split in dataset:
        sampler = torch.utils.data.RandomSampler(dataset[split])
        dataflow[split] = torch.utils.data.DataLoader(
            dataset[split],
            batch_size=256,
            sampler=sampler,
            num_workers=8,
            pin_memory=True,
        )

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    _model = QFCModel().to(device)
    model = torch.compile(_model, backend='hidet')

    n_epochs = args.epochs
    optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)

    if args.static:
        # optionally switch to static mode, which can speed up
        # training
        model.q_layer.static_on(wires_per_block=args.wires_per_block)

    for epoch in range(1, n_epochs + 1):
        # train
        print(f"Epoch {epoch}:")
        train(dataflow, model, device, optimizer)
        print(optimizer.param_groups[0]["lr"])

        # valid
        valid_test(dataflow, "valid", model, device)
        scheduler.step()

    # test
    valid_test(dataflow, "test", model, device, qiskit=False)

    # run on Qiskit simulator and real Quantum Computers
    try:
        from qiskit import IBMQ
        from torchquantum.plugins import QiskitProcessor

        # first, run on the Qiskit simulator
        print(f"\nTest with Qiskit Simulator")
        processor_simulation = QiskitProcessor(use_real_qc=False)
        model.set_qiskit_processor(processor_simulation)
        valid_test(dataflow, "test", model, device, qiskit=True)

        # then try to run on REAL QC
        backend_name = "ibmq_lima"
        print(f"\nTest on Real Quantum Computer {backend_name}")
        # Please specify your own hub group and project if you have the
        # IBMQ premium plan to access more machines.
        processor_real_qc = QiskitProcessor(
            use_real_qc=True,
            backend_name=backend_name,
            hub="ibm-q",
            group="open",
            project="main",
        )
        model.set_qiskit_processor(processor_real_qc)
        valid_test(dataflow, "test", model, device, qiskit=True)
    except ImportError:
        print(
            "Please install qiskit, create an IBM Q Experience Account and "
            "save the account token according to the instruction at "
            "'https://github.com/Qiskit/qiskit-ibmq-provider', "
            "then try again."
        )

if __name__ == "__main__":
    main()

Traceback (most recent call last):
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1194, in run_node
    return node.target(*args, **kwargs)
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1152, in get_fake_value
    return wrap_fake_exception(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 808, in wrap_fake_exception
    return fn()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1153, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1206, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <built-in method tensor of type object at 0x7f82e375b500>(*([FakeTensor(FakeTensor(..., device='meta', size=(), dtype=torch.int64), cpu), FakeTensor(FakeTensor(..., device='meta', size=(), dtype=torch.int64), cpu)],), **{'dtype': torch.float32}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
(scroll up for backtrace)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 262, in <module>
    main()
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 217, in main
    train(dataflow, model, device, optimizer)
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 122, in train
    outputs = model(inputs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
    return fn(*args, **kwargs)
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 68, in forward
    qdev = tq.QuantumDevice(
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 106, in <graph break in forward>
    self.encoder(qdev, x)
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 108, in <graph break in forward>
    self.q_layer(qdev)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoyao/repos/hidet/experiments/issue-265/main.py", line 40, in forward
    self.random_layer(qdev)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torchquantum/graph.py", line 25, in forward_register_graph
    res = f(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torchquantum/layers.py", line 240, in forward
    op(q_device)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
    return fn(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
    return _compile(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
    tracer.run()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
    super().run()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1014, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 291, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 259, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 92, in call_function
    return tx.inline_user_function_return(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
    tracer.run()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1014, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 259, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 92, in call_function
    return tx.inline_user_function_return(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 510, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1806, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1862, in inline_call_
    tracer.run()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
    and self.step()
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
    getattr(self, inst.opname)(inst)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 342, in wrapper
    return inner_fn(self, inst)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1014, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 474, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/torch.py", line 548, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 754, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 789, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1173, in get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError: 

from user code:
   File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torchquantum/functional.py", line 1659, in cnot
    gate_wrapper(
  File "/home/yaoyao/miniconda3/lib/python3.8/site-packages/torchquantum/functional.py", line 258, in gate_wrapper
    params = torch.tensor(params, dtype=F_DTYPE)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True

wcqc commented 1 year ago

Hi @yaoyaoding,

Thanks for adding the operators. I will debug the new errors further.

When you say there are lots of fusion opportunities, do you mean that, even without complete complex dtype support, we should see a noticeable speedup with hidet on the above code once the new errors are fixed? (Of course this is empirical; I'm just asking whether this would be the case in principle.)

Thanks again.

yaoyaoding commented 1 year ago

Hi @wcqc,

That depends. I am not familiar with the operators commonly used in quantum networks or with which of them are the bottleneck. Usually, fusion greatly reduces memory access and speeds up the network. But if the bottleneck is an operator like a large matrix multiplication, that limits the speedup you get from fusing the other, small operators. For example, if the small operators take 20% of the time, you can get at most a 1/0.8 = 1.25x speedup, even if you optimize that 20% down to zero.
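The 1/0.8 bound is Amdahl's law. If a fraction p of the runtime is sped up by a factor s, the overall speedup is 1 / ((1 - p) + p / s). A small sketch follows; the function name is illustrative:

```python
def amdahl_speedup(fraction, factor=float("inf")):
    """Overall speedup when `fraction` of the runtime is accelerated by `factor`.

    factor=inf models optimizing that part away entirely.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be in [0, 1]")
    return 1.0 / ((1.0 - fraction) + fraction / factor)


# Small operators take 20% of the time; fusion makes them 2x, 5x, or infinitely faster.
for s in (2, 5, float("inf")):
    print(f"small ops {s}x faster -> overall {amdahl_speedup(0.2, s):.3f}x")
```

Even a perfect fusion of the small operators caps out at 1.25x here. Larger gains would have to come from the dominant operator itself, e.g., a better complex matmul schedule.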

Thanks for trying hidet, and let me know if you find a way to fix or avoid the above errors (e.g., by writing the PyTorch program in another way).