DeepWok / mase

Machine-Learning Accelerator System Exploration Tools
Other
104 stars 52 forks source link

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. #180

Open tring27 opened 1 month ago

tring27 commented 1 month ago

I am trying to repro Lab 4: https://github.com/DeepWok/mase/blob/main/docs/labs/lab4-hardware.ipynb I am using commit ID: 047f27b9b156a7575b416c1156ea571c229cf9e8, where I updated the test_verilog_analysis_pass.

Steps to repro:

  1. create a python file my.py
import os, sys

from chop.ir.graph.mase_graph import MaseGraph

from chop.passes.graph.analysis import (
    init_metadata_analysis_pass,
    add_common_metadata_analysis_pass,
    add_hardware_metadata_analysis_pass,
    report_node_type_analysis_pass,
    test_verilog_analysis_pass,
)

from chop.passes.graph.transforms import (
    emit_verilog_top_transform_pass,
    emit_internal_rtl_transform_pass,
    emit_bram_transform_pass,
    emit_cocotb_transform_pass,
    quantize_transform_pass,
)

from chop.tools.logger import set_logging_verbosity

set_logging_verbosity("debug")

import toml
import torch
import torch.nn as nn

class MLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.fc1 = nn.Linear(4, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = torch.nn.functional.relu(self.fc1(x))
        return x

mlp = MLP()
mg = MaseGraph(model=mlp)

batch_size = 1
x = torch.randn((batch_size, 2, 2))
dummy_in = {"x": x}

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(
    mg, {"dummy_in": dummy_in, "add_value": False}
)

config_file = os.path.join(
    os.path.abspath(""),
    #"..",
    "..",
    "machop",
    "configs",
    "tests",
    "quantize",
    "fixed.toml",
)
with open(config_file, "r") as f:
    quan_args = toml.load(f)["passes"]["quantize"]
print("QUAN ARGS.....")
print(quan_args)
mg, _ = quantize_transform_pass(mg, quan_args)

_ = report_node_type_analysis_pass(mg)

for node in mg.fx_graph.nodes:
    for arg, arg_info in node.meta["mase"]["common"]["args"].items():
        if isinstance(arg_info, dict):
            arg_info["type"] = "fixed"
            arg_info["precision"] = [8, 3]
    for result, result_info in node.meta["mase"]["common"]["results"].items():
        if isinstance(result_info, dict):
            result_info["type"] = "fixed"
            result_info["precision"] = [8, 3]

mg, _ = add_hardware_metadata_analysis_pass(mg, None)
mg, _ = emit_verilog_top_transform_pass(mg)
mg, _ = emit_internal_rtl_transform_pass(mg)

mg, _ = emit_bram_transform_pass(mg)
mg, _ = test_verilog_analysis_pass(mg)
  1. python my.py
ChengZhang-98 commented 1 month ago

I am trying to repro Lab 4: https://github.com/DeepWok/mase/blob/main/docs/labs/lab4-hardware.ipynb I am using commit ID: 047f27b, where I updated the test_verilog_analysis_pass.

Steps to repro:

  1. create a python file my.py
import os, sys

from chop.ir.graph.mase_graph import MaseGraph

from chop.passes.graph.analysis import (
    init_metadata_analysis_pass,
    add_common_metadata_analysis_pass,
    add_hardware_metadata_analysis_pass,
    report_node_type_analysis_pass,
    test_verilog_analysis_pass,
)

from chop.passes.graph.transforms import (
    emit_verilog_top_transform_pass,
    emit_internal_rtl_transform_pass,
    emit_bram_transform_pass,
    emit_cocotb_transform_pass,
    quantize_transform_pass,
)

from chop.tools.logger import set_logging_verbosity

set_logging_verbosity("debug")

import toml
import torch
import torch.nn as nn

class MLP(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.fc1 = nn.Linear(4, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        x = torch.nn.functional.relu(self.fc1(x))
        return x

mlp = MLP()
mg = MaseGraph(model=mlp)

batch_size = 1
x = torch.randn((batch_size, 2, 2))
dummy_in = {"x": x}

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(
    mg, {"dummy_in": dummy_in, "add_value": False}
)

config_file = os.path.join(
    os.path.abspath(""),
    #"..",
    "..",
    "machop",
    "configs",
    "tests",
    "quantize",
    "fixed.toml",
)
with open(config_file, "r") as f:
    quan_args = toml.load(f)["passes"]["quantize"]
print("QUAN ARGS.....")
print(quan_args)
mg, _ = quantize_transform_pass(mg, quan_args)

_ = report_node_type_analysis_pass(mg)

for node in mg.fx_graph.nodes:
    for arg, arg_info in node.meta["mase"]["common"]["args"].items():
        if isinstance(arg_info, dict):
            arg_info["type"] = "fixed"
            arg_info["precision"] = [8, 3]
    for result, result_info in node.meta["mase"]["common"]["results"].items():
        if isinstance(result_info, dict):
            result_info["type"] = "fixed"
            result_info["precision"] = [8, 3]

mg, _ = add_hardware_metadata_analysis_pass(mg, None)
mg, _ = emit_verilog_top_transform_pass(mg)
mg, _ = emit_internal_rtl_transform_pass(mg)

mg, _ = emit_bram_transform_pass(mg)
mg, _ = test_verilog_analysis_pass(mg)
  1. python my.py

Could you try adding torch.set_grad_enabled(False) after importing packages? I think the raised error is because the passes are changing weights while pytorch autograd is tracing the computation graph for backpropagation. Autograd is not necessary for verilog emission.