leanprover / lean4

Inferred value for implicit argument leads to suboptimal IR #4157

Open TwoFX opened 2 weeks ago

TwoFX commented 2 weeks ago

Description

Consider the following code:

-- noinline is just for clarity, does not affect the issue
@[noinline] def g {a : Nat} (_ha : a < 5) : Nat :=
  a + 2

set_option trace.compiler.ir.reset_reuse true
def f (p : {p : Nat × Nat // p.1 < 5}) : Nat × Nat :=
  let ⟨⟨a, b⟩, ha⟩ := p
  let b' := g ha
  if b' < 5 then
    ⟨a + 1, b'⟩
  else
    ⟨a, b⟩

The generated IR looks like this:

[reset_reuse]
def f (x_1 : obj) : obj :=
  case x_1 : obj of
  Prod.mk →
    let x_2 : obj := proj[0] x_1;
    let x_3 : obj := proj[1] x_1;
    let x_11 : obj := reset[2] x_1;
    let x_4 : obj := reuse x_11 in ctor_0[Prod.mk] x_2 x_3;
    let x_5 : obj := g x_2 ◾;
    let x_6 : obj := 5;
    let x_7 : u8 := Nat.decLt x_5 x_6;
    case x_7 : obj of
    Bool.false →
      ret x_4
    Bool.true →
      let x_8 : obj := 1;
      let x_9 : obj := Nat.add x_2 x_8;
      let x_10 : obj := ctor_0[Prod.mk] x_9 x_5;
      ret x_10

This is suboptimal: x_4 is always built, even though it is identical to x_1 and only needed in the else branch of the if. Later stages of the IR pipeline optimize the constructor call away when x_1 is not shared, so we are left with

[result]
def f (x_1 : obj) : obj :=
  let x_2 : u8 := isShared x_1;
  case x_2 : u8 of
  Bool.false →
    let x_3 : obj := proj[0] x_1;
    inc x_3;
    let x_4 : obj := g x_3 ◾;
    let x_5 : obj := 5;
    let x_6 : u8 := Nat.decLt x_4 x_5;
    case x_6 : u8 of
    Bool.false →
      dec x_4;
      dec x_3;
      ret x_1
    Bool.true →
      dec x_1;
      let x_7 : obj := 1;
      let x_8 : obj := Nat.add x_3 x_7;
      dec x_3;
      let x_9 : obj := ctor_0[Prod.mk] x_8 x_4;
      ret x_9
  Bool.true →
    let x_10 : obj := proj[0] x_1;
    let x_11 : obj := proj[1] x_1;
    inc x_11;
    inc x_10;
    dec x_1;
    inc x_10;
    let x_12 : obj := ctor_0[Prod.mk] x_10 x_11;
    let x_13 : obj := g x_10 ◾;
    let x_14 : obj := 5;
    let x_15 : u8 := Nat.decLt x_13 x_14;
    case x_15 : u8 of
    Bool.false →
      dec x_13;
      dec x_10;
      ret x_12
    Bool.true →
      dec x_12;
      let x_16 : obj := 1;
      let x_17 : obj := Nat.add x_10 x_16;
      dec x_10;
      let x_18 : obj := ctor_0[Prod.mk] x_17 x_13;
      ret x_18

This is still bad, because the true branch of the if never reuses the memory cell of x_1: even when x_1 is not shared, it is released with dec and a fresh Prod.mk is allocated. The reason seems to be that (a, b).fst is inferred as the implicit argument to g. If the line let b' := g ha is changed to let b' := g (a := a) ha, then we get the expected IR:

[reset_reuse]
def f (x_1 : obj) : obj :=
  case x_1 : obj of
  Prod.mk →
    let x_2 : obj := proj[0] x_1;
    let x_3 : obj := proj[1] x_1;
    let x_11 : obj := reset[2] x_1;
    let x_4 : obj := g x_2 ◾;
    let x_5 : obj := 5;
    let x_6 : u8 := Nat.decLt x_4 x_5;
    case x_6 : obj of
    Bool.false →
      let x_7 : obj := reuse x_11 in ctor_0[Prod.mk] x_2 x_3;
      ret x_7
    Bool.true →
      let x_8 : obj := 1;
      let x_9 : obj := Nat.add x_2 x_8;
      let x_10 : obj := reuse x_11 in ctor_0[Prod.mk] x_9 x_4;
      ret x_10
[result]
def f (x_1 : obj) : obj :=
  let x_2 : u8 := isShared x_1;
  case x_2 : u8 of
  Bool.false →
    let x_3 : obj := proj[0] x_1;
    let x_4 : obj := proj[1] x_1;
    let x_5 : obj := g x_3 ◾;
    let x_6 : obj := 5;
    let x_7 : u8 := Nat.decLt x_5 x_6;
    case x_7 : u8 of
    Bool.false →
      dec x_5;
      ret x_1
    Bool.true →
      dec x_4;
      let x_8 : obj := 1;
      let x_9 : obj := Nat.add x_3 x_8;
      dec x_3;
      set x_1[1] := x_5;
      set x_1[0] := x_9;
      ret x_1
  Bool.true →
    let x_10 : obj := proj[0] x_1;
    let x_11 : obj := proj[1] x_1;
    inc x_11;
    inc x_10;
    dec x_1;
    let x_12 : obj := g x_10 ◾;
    let x_13 : obj := 5;
    let x_14 : u8 := Nat.decLt x_12 x_13;
    case x_14 : u8 of
    Bool.false →
      dec x_12;
      let x_15 : obj := ctor_0[Prod.mk] x_10 x_11;
      ret x_15
    Bool.true →
      dec x_11;
      let x_16 : obj := 1;
      let x_17 : obj := Nat.add x_10 x_16;
      dec x_10;
      let x_18 : obj := ctor_0[Prod.mk] x_17 x_12;
      ret x_18
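
For reference, the workaround applied to the whole definition might look like the sketch below. It is the original code with only the call to g changed; the name f' is chosen here for illustration so that it can sit next to the original f in the same file.

```lean
-- noinline is just for clarity, as in the original example
@[noinline] def g {a : Nat} (_ha : a < 5) : Nat :=
  a + 2

-- f' is a hypothetical name for the variant with the workaround applied
set_option trace.compiler.ir.reset_reuse true in
def f' (p : {p : Nat × Nat // p.1 < 5}) : Nat × Nat :=
  let ⟨⟨a, b⟩, ha⟩ := p
  -- pass the implicit argument by name instead of letting it be inferred from ha
  let b' := g (a := a) ha
  if b' < 5 then
    ⟨a + 1, b'⟩
  else
    ⟨a, b⟩
```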

Context

This is a minimization of an issue I observed while working on the hash map, see my comment there.

Steps to Reproduce

  1. Run the code above and look at the trace output
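
To also see which term was filled in for the implicit argument of g, one option might be to print the elaborated definition with explicit arguments shown. This is only a sketch: it repeats the definitions from the description so that it is self-contained, and the exact shape of the printed term (in particular how the pattern match is displayed) is not something this report has checked.

```lean
@[noinline] def g {a : Nat} (_ha : a < 5) : Nat :=
  a + 2

def f (p : {p : Nat × Nat // p.1 < 5}) : Nat × Nat :=
  let ⟨⟨a, b⟩, ha⟩ := p
  let b' := g ha
  if b' < 5 then
    ⟨a + 1, b'⟩
  else
    ⟨a, b⟩

-- pp.explicit makes implicit arguments visible in the printed term,
-- so the argument inferred for `a` in the call to g should show up here
set_option pp.explicit true in
#print f
```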

Expected behavior: Quality of the generated IR should not depend on whether we used (a, b).1 or a for the implicit argument, given that the IR clearly shows that in both cases we just pass a to g.

Actual behavior: Suboptimal IR in case (a, b).1 is used as the implicit argument.
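
The expectation rests on the two candidate arguments being definitionally equal. A minimal check of that, independent of f and g, could look as follows; the hypothesis name h is only illustrative.

```lean
-- the projection of an explicit pair reduces to its first component
example (a b : Nat) : (a, b).1 = a := rfl

-- so a proof about (a, b).1 is accepted where a proof about a is expected
example (a b : Nat) (h : (a, b).1 < 5) : a < 5 := h
```

Since both forms are interchangeable at the type level, the difference in the generated IR appears to come purely from which syntactic form the elaborator happens to pick.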

Versions

4.9.0-nightly-2024-05-11

Impact

Add :+1: to issues you consider important. If others are impacted by this issue, please ask them to add :+1: to it.