llvm / circt

Circuit IR Compilers and Tools
https://circt.org
Other
1.67k stars 298 forks source link

[ExportVerilog] CIRCT may generate wrong RTL code for if/else mux chain. #4419

Closed Siudya closed 1 year ago

Siudya commented 1 year ago

This issue is a follow-up to #4399. OS: Ubuntu 20.04. CIRCT: SiFive Internal Release 1.24.0. Chisel: 3.5.5. Description: The size of the generated Verilog file has been reduced, but its function seems wrong. The .fir file is:

circuit MuxChain :
  module MuxChain :
    input clock : Clock
    input reset : UInt<1>
    output io : { flip valids : UInt<1>[4], flip index : UInt<2>, out : UInt<2>[4]}

    reg r : UInt<2>[4], clock with :
      reset => (UInt<1>("h0"), r) @[MuxChain.scala 11:14]
    when io.valids[0] : @[MuxChain.scala 14:13]
      r[io.index] <= UInt<1>("h1") @[MuxChain.scala 15:19]
    else :
      r[io.index] <= UInt<2>("h2") @[MuxChain.scala 17:19]
    when io.valids[1] : @[MuxChain.scala 14:13]
      r[io.index] <= UInt<1>("h1") @[MuxChain.scala 15:19]
    else :
      r[io.index] <= UInt<2>("h2") @[MuxChain.scala 17:19]
    when io.valids[2] : @[MuxChain.scala 14:13]
      r[io.index] <= UInt<1>("h1") @[MuxChain.scala 15:19]
    else :
      r[io.index] <= UInt<2>("h2") @[MuxChain.scala 17:19]
    when io.valids[3] : @[MuxChain.scala 14:13]
      r[io.index] <= UInt<1>("h1") @[MuxChain.scala 15:19]
    else :
      r[io.index] <= UInt<2>("h2") @[MuxChain.scala 17:19]
    io.out <= r @[MuxChain.scala 21:10]

When SFC is used, the .sv file is like this:

module MuxChain(
  input        clock,
  input        reset,
  input        io_valids_0,
  input        io_valids_1,
  input        io_valids_2,
  input        io_valids_3,
  input  [1:0] io_index,
  output [1:0] io_out_0,
  output [1:0] io_out_1,
  output [1:0] io_out_2,
  output [1:0] io_out_3
);
  reg [1:0] r_0; // @[MuxChain.scala 11:14]
  reg [1:0] r_1; // @[MuxChain.scala 11:14]
  reg [1:0] r_2; // @[MuxChain.scala 11:14]
  reg [1:0] r_3; // @[MuxChain.scala 11:14]
  wire [1:0] _GEN_0 = 2'h0 == io_index ? 2'h1 : r_0; // @[MuxChain.scala 11:14 15:{19,19}]
  wire [1:0] _GEN_1 = 2'h1 == io_index ? 2'h1 : r_1; // @[MuxChain.scala 11:14 15:{19,19}]
  wire [1:0] _GEN_2 = 2'h2 == io_index ? 2'h1 : r_2; // @[MuxChain.scala 11:14 15:{19,19}]
  wire [1:0] _GEN_3 = 2'h3 == io_index ? 2'h1 : r_3; // @[MuxChain.scala 11:14 15:{19,19}]
  wire [1:0] _GEN_4 = 2'h0 == io_index ? 2'h2 : r_0; // @[MuxChain.scala 11:14 17:{19,19}]
  wire [1:0] _GEN_5 = 2'h1 == io_index ? 2'h2 : r_1; // @[MuxChain.scala 11:14 17:{19,19}]
  wire [1:0] _GEN_6 = 2'h2 == io_index ? 2'h2 : r_2; // @[MuxChain.scala 11:14 17:{19,19}]
  wire [1:0] _GEN_7 = 2'h3 == io_index ? 2'h2 : r_3; // @[MuxChain.scala 11:14 17:{19,19}]
  wire [1:0] _GEN_8 = io_valids_0 ? _GEN_0 : _GEN_4; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_9 = io_valids_0 ? _GEN_1 : _GEN_5; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_10 = io_valids_0 ? _GEN_2 : _GEN_6; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_11 = io_valids_0 ? _GEN_3 : _GEN_7; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_12 = 2'h0 == io_index ? 2'h1 : _GEN_8; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_13 = 2'h1 == io_index ? 2'h1 : _GEN_9; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_14 = 2'h2 == io_index ? 2'h1 : _GEN_10; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_15 = 2'h3 == io_index ? 2'h1 : _GEN_11; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_16 = 2'h0 == io_index ? 2'h2 : _GEN_8; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_17 = 2'h1 == io_index ? 2'h2 : _GEN_9; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_18 = 2'h2 == io_index ? 2'h2 : _GEN_10; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_19 = 2'h3 == io_index ? 2'h2 : _GEN_11; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_20 = io_valids_1 ? _GEN_12 : _GEN_16; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_21 = io_valids_1 ? _GEN_13 : _GEN_17; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_22 = io_valids_1 ? _GEN_14 : _GEN_18; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_23 = io_valids_1 ? _GEN_15 : _GEN_19; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_24 = 2'h0 == io_index ? 2'h1 : _GEN_20; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_25 = 2'h1 == io_index ? 2'h1 : _GEN_21; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_26 = 2'h2 == io_index ? 2'h1 : _GEN_22; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_27 = 2'h3 == io_index ? 2'h1 : _GEN_23; // @[MuxChain.scala 15:{19,19}]
  wire [1:0] _GEN_28 = 2'h0 == io_index ? 2'h2 : _GEN_20; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_29 = 2'h1 == io_index ? 2'h2 : _GEN_21; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_30 = 2'h2 == io_index ? 2'h2 : _GEN_22; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_31 = 2'h3 == io_index ? 2'h2 : _GEN_23; // @[MuxChain.scala 17:{19,19}]
  wire [1:0] _GEN_32 = io_valids_2 ? _GEN_24 : _GEN_28; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_33 = io_valids_2 ? _GEN_25 : _GEN_29; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_34 = io_valids_2 ? _GEN_26 : _GEN_30; // @[MuxChain.scala 14:13]
  wire [1:0] _GEN_35 = io_valids_2 ? _GEN_27 : _GEN_31; // @[MuxChain.scala 14:13]
  assign io_out_0 = r_0; // @[MuxChain.scala 21:10]
  assign io_out_1 = r_1; // @[MuxChain.scala 21:10]
  assign io_out_2 = r_2; // @[MuxChain.scala 21:10]
  assign io_out_3 = r_3; // @[MuxChain.scala 21:10]
  always @(posedge clock) begin
    if (io_valids_3) begin // @[MuxChain.scala 14:13]
      if (2'h0 == io_index) begin // @[MuxChain.scala 15:19]
        r_0 <= 2'h1; // @[MuxChain.scala 15:19]
      end else begin
        r_0 <= _GEN_32;
      end
    end else if (2'h0 == io_index) begin // @[MuxChain.scala 17:19]
      r_0 <= 2'h2; // @[MuxChain.scala 17:19]
    end else begin
      r_0 <= _GEN_32;
    end
    if (io_valids_3) begin // @[MuxChain.scala 14:13]
      if (2'h1 == io_index) begin // @[MuxChain.scala 15:19]
        r_1 <= 2'h1; // @[MuxChain.scala 15:19]
      end else begin
        r_1 <= _GEN_33;
      end
    end else if (2'h1 == io_index) begin // @[MuxChain.scala 17:19]
      r_1 <= 2'h2; // @[MuxChain.scala 17:19]
    end else begin
      r_1 <= _GEN_33;
    end
    if (io_valids_3) begin // @[MuxChain.scala 14:13]
      if (2'h2 == io_index) begin // @[MuxChain.scala 15:19]
        r_2 <= 2'h1; // @[MuxChain.scala 15:19]
      end else begin
        r_2 <= _GEN_34;
      end
    end else if (2'h2 == io_index) begin // @[MuxChain.scala 17:19]
      r_2 <= 2'h2; // @[MuxChain.scala 17:19]
    end else begin
      r_2 <= _GEN_34;
    end
    if (io_valids_3) begin // @[MuxChain.scala 14:13]
      if (2'h3 == io_index) begin // @[MuxChain.scala 15:19]
        r_3 <= 2'h1; // @[MuxChain.scala 15:19]
      end else begin
        r_3 <= _GEN_35;
      end
    end else if (2'h3 == io_index) begin // @[MuxChain.scala 17:19]
      r_3 <= 2'h2; // @[MuxChain.scala 17:19]
    end else begin
      r_3 <= _GEN_35;
    end
  end
endmodule

When MFC is used, the .sv file is like this:

// Generated by CIRCT sifive/1/24/0
module MuxChain(    // <stdin>:3:10
  input        clock,
               reset,
               io_valids_0,
               io_valids_1,
               io_valids_2,
               io_valids_3,
  input  [1:0] io_index,
  output [1:0] io_out_0,
               io_out_1,
               io_out_2,
               io_out_3);

  reg [1:0] r_0;    // MuxChain.scala:11:14
  reg [1:0] r_1;    // MuxChain.scala:11:14
  reg [1:0] r_2;    // MuxChain.scala:11:14
  reg [1:0] r_3;    // MuxChain.scala:11:14
  always @(posedge clock) begin
    automatic logic [1:0] _GEN = io_valids_3 ? 2'h1 : 2'h2; // MuxChain.scala:14:13, :15:19, :17:19
    r_0 <= io_index == 2'h0 ? _GEN : r_0;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    r_1 <= io_index == 2'h1 ? _GEN : r_1;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    r_2 <= io_index == 2'h2 ? _GEN : r_2;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    r_3 <= (&io_index) ? _GEN : r_3;    // MuxChain.scala:11:14, :14:13, :15:19, :17:19
  end // always @(posedge)
  assign io_out_0 = r_0;    // <stdin>:3:10, MuxChain.scala:11:14
  assign io_out_1 = r_1;    // <stdin>:3:10, MuxChain.scala:11:14
  assign io_out_2 = r_2;    // <stdin>:3:10, MuxChain.scala:11:14
  assign io_out_3 = r_3;    // <stdin>:3:10, MuxChain.scala:11:14
endmodule

MFC unexpectedly overrides the previous write conditions with the last write condition; these writes should instead be treated as a sequence of prioritized writes (a mux chain).

Reproduction: The Chisel source code is:

import chisel3._
import chisel3.util._
class MuxChain(aLen: Int = 4, w: Int = 2) extends Module {
  val io = IO(new Bundle() {
    val valids = Input(Vec(aLen, Bool()))
    val index = Input(UInt(log2Up(aLen).W))
    val out = Output(Vec(aLen, UInt(w.W)))
  })

  val r = Reg(Vec(aLen, UInt(w.W)))

  for ((v, i) <- io.valids.zipWithIndex) {
    when(v) {
      r(io.index) := 1.U
    }.otherwise({
      r(io.index) := 2.U
    })
  }
  io.out := r
}

Run this to generate the Verilog from SFC:

object GenMuxChainSFC extends App {
  val iargs = Seq("-td", "build", "-X", "sverilog", "-E", "sverilog", "--emission-options","disableRegisterRandomization")
  (new chisel3.stage.ChiselStage).execute(iargs.toArray,
    Seq(
      ChiselGeneratorAnnotation(() => new MuxChain(4, 2)),
    )
  )
}

Run this command to generate the Verilog from MFC:

firtool -format=fir --disable-all-randomization --dedup --verilog -o build/MuxChainMFC.sv build/MuxChain.fir
Siudya commented 1 year ago

The result from CIRCT 1.15.0 seems good.

// Generated by CIRCT sifive/1/15/0
// Standard header to adapt well known macros to our needs.
`ifdef RANDOMIZE_REG_INIT
  `define RANDOMIZE
`endif

// RANDOM may be set to an expression that produces a 32-bit random unsigned value.
`ifndef RANDOM
  `define RANDOM $random
`endif

// Users can define INIT_RANDOM as general code that gets injected into the
// initializer block for modules with registers.
`ifndef INIT_RANDOM
  `define INIT_RANDOM
`endif

// If using random initialization, you can also define RANDOMIZE_DELAY to
// customize the delay used, otherwise 0.002 is used.
`ifndef RANDOMIZE_DELAY
  `define RANDOMIZE_DELAY 0.002
`endif

// Define INIT_RANDOM_PROLOG_ for use in our modules below.
`ifdef RANDOMIZE
  `ifdef VERILATOR
    `define INIT_RANDOM_PROLOG_ `INIT_RANDOM
  `else
    `define INIT_RANDOM_PROLOG_ `INIT_RANDOM #`RANDOMIZE_DELAY begin end
  `endif
`else
  `define INIT_RANDOM_PROLOG_
`endif

module MuxChain(
  input        clock,
               reset,
               io_valids_0,
               io_valids_1,
               io_valids_2,
               io_valids_3,
  input  [1:0] io_index,
  output [1:0] io_out_0,
               io_out_1,
               io_out_2,
               io_out_3);

  reg [1:0] r_0;    // MuxChain.scala:11:14
  reg [1:0] r_1;    // MuxChain.scala:11:14
  reg [1:0] r_2;    // MuxChain.scala:11:14
  reg [1:0] r_3;    // MuxChain.scala:11:14
  always @(posedge clock) begin
    automatic logic       _GEN = io_index == 2'h0;  // MuxChain.scala:15:19
    automatic logic       _GEN_0 = io_index == 2'h1;    // MuxChain.scala:15:19
    automatic logic       _GEN_1 = io_index == 2'h2;    // MuxChain.scala:15:19, :17:19
    automatic logic [1:0] _GEN_2;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_3;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_4;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_5;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_6;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_7;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_8;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_9;   // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_10;  // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_11;  // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_12;  // MuxChain.scala:14:13, :15:19, :17:19
    automatic logic [1:0] _GEN_13;  // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_2 = io_valids_0 ? (_GEN ? 2'h1 : r_0) : _GEN ? 2'h2 : r_0; // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    _GEN_3 = io_valids_0 ? (_GEN_0 ? 2'h1 : r_1) : _GEN_0 ? 2'h2 : r_1; // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    _GEN_4 = io_valids_0 ? (_GEN_1 ? 2'h1 : r_2) : _GEN_1 ? 2'h2 : r_2; // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    _GEN_5 = io_valids_0 ? (&io_index ? 2'h1 : r_3) : &io_index ? 2'h2 : r_3;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    _GEN_6 = io_valids_1 ? (_GEN ? 2'h1 : _GEN_2) : _GEN ? 2'h2 : _GEN_2;   // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_7 = io_valids_1 ? (_GEN_0 ? 2'h1 : _GEN_3) : _GEN_0 ? 2'h2 : _GEN_3;   // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_8 = io_valids_1 ? (_GEN_1 ? 2'h1 : _GEN_4) : _GEN_1 ? 2'h2 : _GEN_4;   // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_9 = io_valids_1 ? (&io_index ? 2'h1 : _GEN_5) : &io_index ? 2'h2 : _GEN_5; // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_10 = io_valids_2 ? (_GEN ? 2'h1 : _GEN_6) : _GEN ? 2'h2 : _GEN_6;  // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_11 = io_valids_2 ? (_GEN_0 ? 2'h1 : _GEN_7) : _GEN_0 ? 2'h2 : _GEN_7;  // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_12 = io_valids_2 ? (_GEN_1 ? 2'h1 : _GEN_8) : _GEN_1 ? 2'h2 : _GEN_8;  // MuxChain.scala:14:13, :15:19, :17:19
    _GEN_13 = io_valids_2 ? (&io_index ? 2'h1 : _GEN_9) : &io_index ? 2'h2 : _GEN_9;    // MuxChain.scala:14:13, :15:19, :17:19
    r_0 <= io_valids_3 ? (_GEN ? 2'h1 : _GEN_10) : _GEN ? 2'h2 : _GEN_10;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    r_1 <= io_valids_3 ? (_GEN_0 ? 2'h1 : _GEN_11) : _GEN_0 ? 2'h2 : _GEN_11;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    r_2 <= io_valids_3 ? (_GEN_1 ? 2'h1 : _GEN_12) : _GEN_1 ? 2'h2 : _GEN_12;   // MuxChain.scala:11:14, :14:13, :15:19, :17:19
    r_3 <= io_valids_3 ? (&io_index ? 2'h1 : _GEN_13) : &io_index ? 2'h2 : _GEN_13; // MuxChain.scala:11:14, :14:13, :15:19, :17:19
  end // always @(posedge)
  `ifndef SYNTHESIS
    `ifdef FIRRTL_BEFORE_INITIAL
      `FIRRTL_BEFORE_INITIAL
    `endif
    initial begin
      automatic logic [31:0] _RANDOM_0;
      `ifdef INIT_RANDOM_PROLOG_
        `INIT_RANDOM_PROLOG_
      `endif
      `ifdef RANDOMIZE_REG_INIT
        _RANDOM_0 = `RANDOM;
        r_0 = _RANDOM_0[1:0];   // MuxChain.scala:11:14
        r_1 = _RANDOM_0[3:2];   // MuxChain.scala:11:14
        r_2 = _RANDOM_0[5:4];   // MuxChain.scala:11:14
        r_3 = _RANDOM_0[7:6];   // MuxChain.scala:11:14
      `endif
    end // initial
    `ifdef FIRRTL_AFTER_INITIAL
      `FIRRTL_AFTER_INITIAL
    `endif
  `endif
  assign io_out_0 = r_0;    // MuxChain.scala:11:14
  assign io_out_1 = r_1;    // MuxChain.scala:11:14
  assign io_out_2 = r_2;    // MuxChain.scala:11:14
  assign io_out_3 = r_3;    // MuxChain.scala:11:14
endmodule
seldridge commented 1 year ago

I was pretty sure that there was no issue here and that @darthscsi's mux folders, added in https://github.com/llvm/circt/pull/4403 and https://github.com/llvm/circt/pull/4405, are simply removing a lot of code. To check this, I ran the SFC and MFC Verilog through Formality, and they are formally equivalent.

You can get some idea of what these folders are doing from the code comments here and here. Note that the input .fir is somewhat confusing, as all the when/else blocks except for the last one are dead code due to last-connect semantics. I realize that I wrote the code example this way in #4399. That code example wasn't supposed to make sense; it was just the first thing that I found which properly exercised the performance bug.