google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

Compiler bug #1216

Closed: pharringtonp19 closed this issue 1 year ago

pharringtonp19 commented 1 year ago
def f(x:Float): Float = 
    x + x 

def compose (f: Float -> Float) (n:Nat) (x:Float) : Float =
    (lf, _) = yield_state (x, 0) \state.
        while do
            (acc, i) = get state
            if (i < n)
                then
                    state := (f acc, i + 1)
                    True
                else False
    lf

z = compose f 3 

z 2.0
> 16.

grad z 2.0

Compiler bug! Please report this at github.com/google-research/dex-lang/issues

not implemented: (while (
  v#0:(Float32 & Word32) = get ref
  v#1:Word8 = %ilt (ProjectElt 1 v#0) 0x3
  v#2:(|Unit | Unit|) = %toEnum (|Unit | Unit|) v#1
  v#3:(|Unit | Unit|) = case v#2 of
    v#3 -> (0| () |)
    v#3 ->
      v#4:Float32 = %fadd (ProjectElt 0 v#0) (ProjectElt 0 v#0)
      v#5:Word32 = %iadd (ProjectElt 1 v#0) 0x1
      v#6:Unit = ref := (v#4, v#5)
      (1| () |)
    case annotated with effects {State h_}
  v#4:Word8 = %dataConTag v#3
  v#4))
CallStack (from HasCallStack):
  error, called at src/lib/Linearize.hs:597:8 in dex-0.1.0.0-56g7HgbAOkMJlMxH8BYl0J:Linearize
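
For reference, the value and derivative that grad z 2.0 should produce can be worked out by hand with the chain rule. Since f doubles its input, every local derivative along the chain is 2:

```latex
\begin{aligned}
f(x)  &= x + x = 2x, \qquad f'(x) = 2,\\
z(x)  &= f(f(f(x))) = 8x, \qquad z(2) = 16,\\
z'(x) &= f'\big(f(f(x))\big)\, f'\big(f(x)\big)\, f'(x) = 2 \cdot 2 \cdot 2 = 8.
\end{aligned}
```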

axch commented 1 year ago

Differentiation through while loops is not implemented (because it requires a List-typed tape, which is a hassle to make efficient). However, you don't need a while loop here, because the trip count is predictable:

def f(x:Float): Float = 
    x + x 

def compose (f: Float -> Float) (n:Nat) (x:Float) : Float =
  yield_state x \state.
    for _:(Fin n).
      state := f (get state)

z = compose f 3 

z 2.0
> 16.

grad z 2.0
> 8.
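
To make the tape argument concrete, here is a minimal hand-written reverse-mode sketch in Python. It is not Dex's implementation (Dex's Linearize pass works on its own IR), and the helper names double, compose_while_vjp and compose_for_vjp are hypothetical. The while version can only record the inputs it needs for the backward pass in a list that grows as the loop runs, whereas the fixed-trip-count version, like for _:(Fin n), can size its tape before the loop starts.

```python
# Hypothetical sketch: hand-written reverse-mode AD for applying a scalar
# function n times, mirroring the Dex `compose` examples above.
# This is NOT how Dex's Linearize pass works; it only illustrates why a
# while loop needs a tape whose length is known only at runtime.
from typing import Callable, List, Tuple

# A differentiable scalar function, modelled as (value fn, derivative fn).
Diff = Tuple[Callable[[float], float], Callable[[float], float]]


def double() -> Diff:
    # f(x) = x + x, so f'(x) = 2 everywhere
    return (lambda x: x + x, lambda x: 2.0)


def compose_while_vjp(f: Diff, n: int, x: float) -> Tuple[float, float]:
    """While-loop forward pass: the tape is a list that grows one entry
    per iteration, because the trip count is not declared up front."""
    fn, dfn = f
    tape: List[float] = []      # inputs saved for the backward pass
    acc, i = x, 0
    while i < n:                # loop exit is decided at runtime
        tape.append(acc)
        acc = fn(acc)
        i += 1
    # Backward pass: walk the tape in reverse, multiplying local derivatives.
    grad = 1.0
    for saved in reversed(tape):
        grad *= dfn(saved)
    return acc, grad


def compose_for_vjp(f: Diff, n: int, x: float) -> Tuple[float, float]:
    """Fixed-trip-count forward pass (like `for _:(Fin n)`): the tape is
    preallocated with exactly n slots before the loop runs."""
    fn, dfn = f
    tape = [0.0] * n            # size known before the loop starts
    acc = x
    for k in range(n):
        tape[k] = acc
        acc = fn(acc)
    grad = 1.0
    for k in range(n - 1, -1, -1):
        grad *= dfn(tape[k])
    return acc, grad


if __name__ == "__main__":
    print(compose_while_vjp(double(), 3, 2.0))  # (16.0, 8.0)
    print(compose_for_vjp(double(), 3, 2.0))    # (16.0, 8.0)
```

Both functions return (16.0, 8.0) for n = 3 and x = 2.0, matching the Dex outputs above. They differ only in how the tape is allocated, which is the part that is easy to make efficient when the trip count is known in advance.
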

pharringtonp19 commented 1 year ago

@axch Thanks for the explanation!