dfdx / Umlaut.jl

The Code Tracer
MIT License

Insertion of statements at gotoifnot statements #26

Closed: willtebbutt closed this issue 1 year ago

willtebbutt commented 1 year ago

Hi. Thanks for creating this excellent package.

I'm interested in inserting an additional statement onto the tape every time that a GotoIfNot statement is hit in the IR. I can't see a straightforward way to achieve this under the current API -- it looks like Umlaut only supports transforming tapes, at which point I believe the information about where GotoIfNot statements appeared has been discarded. Have I understood this correctly?

If I am correct in my understanding, do you have any suggestions for how I might achieve this in a hacky way for now (e.g. I'm happy to dev this package and work with a local version for now)? I'm just looking to prototype stuff at the minute, so a solution which isn't officially supported, but which works locally, would be fine for me (and maybe we could look at upstreaming it at some point?)

Thanks!

willtebbutt commented 1 year ago

Upon closer inspection, it looks like it's as straightforward as replacing this line with something along the lines of

                v = t.tape[frame.ir2tape[cf.cond]].val
                push!(t.tape, mkcall(println, "GotoIfNot Node hit!"))
                v

Does this seem like a reasonable approach?

dfdx commented 1 year ago

Yes, the code snippet you provided looks valid to me. If you give more details about the problem you are working on, I may also have a better suggestion, e.g. how to do it without a local copy of the package.

willtebbutt commented 1 year ago

Fab. The modification does seem to be doing what I need, so that's good.

I'm trying to arrange for additional nodes to be added to the linearised trace wherever value-dependent control flow occurred in the original programme (I believe GotoIfNot nodes are the only relevant nodes here). These nodes would act as assertions, checking that the same control path is taken when the trace is executed at new inputs.

For example, the programme

function foo(x, n)
    for _ in 1:n
        x += 1
    end
    return x
end

produces different traces if you change the value of n. I want to ensure that if someone does the following

_, tape = trace(foo, 5.0, 5)
play!(tape, foo, 5.0, 4)

an error is thrown, saying that different control flow was taken.
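The idea can be sketched in plain Julia, without Umlaut. This is only a toy model under stated assumptions: `ControlFlowMismatch`, `check_branch`, `record_branches` and `replay` are hypothetical names made up for illustration and are not part of Umlaut's API. `record_branches` stands in for tracing `foo` and remembers the value of each loop test, and `replay` stands in for re-running the linearised trace, checking each test against the recorded value.

```julia
# Hypothetical exception type (not part of Umlaut): raised when a replay
# reaches a branch whose condition differs from the one seen while tracing.
struct ControlFlowMismatch <: Exception
    expected::Bool
    actual::Bool
end

# Hypothetical guard: returns `cond` if it matches the recorded value,
# and throws otherwise.
function check_branch(cond::Bool, expected::Bool)
    cond == expected || throw(ControlFlowMismatch(expected, cond))
    return cond
end

# Toy stand-in for tracing `foo`: runs the loop and records the outcome
# of every loop test, i.e. every point where a GotoIfNot would be hit.
function record_branches(x, n)
    branches = Bool[]
    i = 1
    while true
        c = i <= n
        push!(branches, c)
        c || break
        x += 1
        i += 1
    end
    return x, branches
end

# Toy stand-in for replaying the linear trace at new inputs: follows the
# recorded path and asserts that each test agrees with the recording.
function replay(branches, x, n)
    i = 1
    for expected in branches
        check_branch(i <= n, expected)
        expected || break
        x += 1
        i += 1
    end
    return x
end
```

With this sketch, recording at `n = 5` and replaying at `n = 5` runs to completion, whereas replaying at `n = 4` throws `ControlFlowMismatch` at the fifth loop test, which is the behaviour I'd like from `play!`.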

dfdx commented 1 year ago

Yes, the approach you described seems to be the optimal one. I think we even had this logic previously, but I could have removed it during a refactoring and never re-introduced it :( If you feel like merging your changes into this repo, I'll be happy to review!

By the way, we have a similar check for the number of output variables during unsplatting, so the concept of additional verification perfectly suits the package.

willtebbutt commented 1 year ago

Excellent -- I'll do that! I'll have a play around with a couple of design choices, and open a PR with one of them.