FluxML / IRTools.jl

Mike's Little Intermediate Representation
MIT License

Incorrect `while` loop #94

Closed wazizian closed 2 years ago

wazizian commented 2 years ago

Hi, I encountered an issue with IRTools that can be reduced to the following example:

using IRTools
using IRTools: @dynamo, recurse!, IR

@dynamo function passthrough(a...)
    ir = IR(a...)
    ir === nothing && return
    recurse!(ir)
    return ir
end

function f(x)
    i = x
    while i > 0
        i -= 1
    end
    return x
end

@show f(1)
@show passthrough(f, 1)

However, on both Julia 1.7.1 and 1.8 (see version info below), this outputs:

f(1) = 1
passthrough(f, 1) = 0

Moreover, this carries over to Zygote:

using Zygote

function g(y)
    return f(1) * y
end

@show gradient(g, 1)

which gives

gradient(g, 1) = (0.0,)

I am happy to investigate this further, but I am somewhat at a loss right now. In particular, I am not sure whether this is a Julia or an IRTools issue, nor, if it is a Julia issue, where to start.
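One possible starting point for narrowing this down is to compare what Julia lowers `f` to with the IR that IRTools builds from it. The sketch below is only a diagnostic, not a fix: it uses `@code_lowered` (from InteractiveUtils) and `@code_ir` (from IRTools), and makes no claim about where the discrepancy actually arises.

```julia
using InteractiveUtils          # provides @code_lowered outside the REPL
using IRTools
using IRTools: @code_ir

function f(x)
    i = x
    while i > 0
        i -= 1
    end
    return x
end

# Lowered code as produced by Julia (slots for `x` and `i` are still visible).
display(@code_lowered f(1))

# The IR that IRTools constructs for the same call.
ir = @code_ir f(1)
display(ir)
```

If the value returned in the IRTools IR is the loop variable rather than the original argument, that would point at the IR construction step rather than at Julia itself.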

Thank you in advance for your help (and thank you for publishing IRTools!).

Cheers,
Waïss

How I arrived at this MWE

I began by investigating incorrect behavior in the `string` function. When a `@dynamo` is used, the function `write(io::IO, s::String)` does not correctly update `io.ptr`. I narrowed it down to the variable `written` in `unsafe_write` (here and called there) being incorrectly decreased in the `while` loop, akin to what happens in the example above.
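For reference, the behavior a transformed call should preserve can be checked with plain Julia. This is a minimal sketch using only Base; it assumes an `IOBuffer` as the `io` and uses the public `position` accessor instead of reading `io.ptr` directly.

```julia
# Untransformed baseline: writing a 3-byte string to a fresh buffer.
io = IOBuffer()
n = write(io, "abc")    # `write` returns the number of bytes written

@show n
@show position(io)      # zero-based stream position after the write
```

Running the same `write` through a dynamo and comparing `n` and `position(io)` against this baseline is one way to observe the symptom described above.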

Version info

I reproduced this bug on both Julia 1.7.1 and Julia 1.8 nightly.

Julia Version 1.7.1
Commit ac5cc99908 (2021-12-22 19:35 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10700 CPU @ 2.90GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, skylake)

Julia Version 1.8.0-DEV.1436
Commit a7beb93dfe (2022-01-31 00:10 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-10700 CPU @ 2.90GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.0 (ORCJIT, skylake)

(Edit: fixed code formatting)

mcabbott commented 2 years ago

This seems related to a Zygote bug whose reference I don't remember. Replacing `i = x` with `i = identity(x)` avoids the problem:

function f2(x)
    i = identity(x)
    while i > 0
        i -= 1
    end
    return x
end

passthrough(f2, 1) == f2(1) == 1

function g2(y)
    return f2(1) * y
end

gradient(g2, 1) == (1.0,)

I believe this is an IRTools problem, and it would be great to solve it!
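The comparison between the original function and the workaround can be turned into a small check that lists the inputs on which the transformed call disagrees with the plain call. This is a sketch assembled from the definitions in this thread; `mismatches` is a hypothetical helper introduced here for illustration, not part of IRTools.

```julia
using IRTools
using IRTools: @dynamo, recurse!, IR

@dynamo function passthrough(a...)
    ir = IR(a...)
    ir === nothing && return
    recurse!(ir)
    return ir
end

# Loop counter initialised directly from the argument (the reported pattern).
function f(x)
    i = x
    while i > 0
        i -= 1
    end
    return x
end

# Same loop, with the counter initialised through a call (the workaround).
function f2(x)
    i = identity(x)
    while i > 0
        i -= 1
    end
    return x
end

# Hypothetical helper: inputs where `passthrough(fn, x)` differs from `fn(x)`.
mismatches(fn, inputs) = [x for x in inputs if passthrough(fn, x) != fn(x)]

@show mismatches(f, 0:3)
@show mismatches(f2, 0:3)
```

An empty result means the dynamo reproduced the plain call on every tested input. What `f` reports depends on the IRTools version in use.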