dfdx / Umlaut.jl

The Code Tracer
MIT License

Towards fixing splatting properly #52

Closed: willtebbutt closed this 11 months ago

willtebbutt commented 11 months ago

It turns out that while #48 didn't break any of the tests in Umlaut, it did break some examples of splatting in code that I'm testing on. This led me to dig around a little, and construct the additional test cases for Umlaut that I've added in this PR.

Notably:

  1. the current implementation fails when you splat any data type which doesn't support getfield -- this is most things other than Tuple and NamedTuple.
  2. the previous implementation worked for anything for which getindex works, which is a wider range of types than getfield supports.
  3. the previous implementation failed for iterators which don't have getindex defined on them. I've created one of these by adding a test involving a zip, for which getindex doesn't work.
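
A minimal sketch of the three cases above, using only Base functions. The `throws` helper and the example values are illustrative assumptions, not taken from the Umlaut test suite:

```julia
# Illustrative helper (not part of Umlaut): does f(args...) throw?
function throws(f, args...)
    try
        f(args...)
        return false
    catch
        return true
    end
end

t = (1.0, 2.0, 3.0)   # Tuple
r = 1:3               # UnitRange, a struct with two fields: start and stop
z = zip(1:3, 4:6)     # lazy iterator

# All three can be splatted.
@assert tuple(t...) == (1.0, 2.0, 3.0)
@assert tuple(r...) == (1, 2, 3)
@assert tuple(z...) == ((1, 4), (2, 5), (3, 6))

# getfield reads struct fields, so it only matches the splatted
# elements for Tuple-like types.
@assert getfield(t, 3) == 3.0
@assert throws(getfield, r, 3)   # only two fields, although r has three elements

# getindex covers the range, but not the zip.
@assert getindex(r, 3) == 3
@assert throws(getindex, z, 1)   # no getindex method for Zip
```

Only the iteration protocol, which splatting itself relies on, covers all three values.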

I'm not entirely sure what the right answer is here. I would really rather not revert to the getindex implementation of unsplat!, because it puts a (potentially) non-primitive on the tape, and doesn't handle all of the cases.

One option is to, in the general case, insert a Core._apply_iterate(itr, tuple, x...) onto the tape, which will output a Tuple -- we could then employ the current strategy on that Tuple (getfield calls etc). This is a little bit annoying, but it at least has the property that it should handle every case. Downstream consumers of the tapes produced by Umlaut could then handle whichever range of options they prefer. We could then add special handlers for Tuples, Vectors, etc.
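
A rough sketch of that idea, assuming the calling convention that Julia itself lowers `f(x...)` to, namely `Core._apply_iterate(Base.iterate, f, x)`. `Core._apply_iterate` is an internal builtin and may change between Julia versions, and the variable names here are illustrative only:

```julia
z = zip(1:2, 3:4)   # getindex is not defined for this iterator

# Collect the iterable into a Tuple with a single builtin call.
as_tuple = Core._apply_iterate(Base.iterate, tuple, z)
@assert as_tuple isa Tuple
@assert as_tuple == ((1, 3), (2, 4))

# The existing getfield-based strategy can now be applied to the Tuple.
args = [getfield(as_tuple, i) for i in 1:nfields(as_tuple)]
@assert args == [(1, 3), (2, 4)]
```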

What are your thoughts @dfdx ?

dfdx commented 11 months ago

I had a similar problem in Yota.jl when I needed a getfield primitive that was protected from the backprop mechanism. I solved it by introducing a special _getfield() function with the same behavior but different "primitiveness". Can we do the same thing here? E.g. introduce some internal function which is guaranteed to be a primitive and does what we need for all the types we care about?

willtebbutt commented 11 months ago

I like this idea. How about the following: for each argument which is getting splatted, we proceed in two steps:

  1. insert a Call which converts to a type which supports getfield (i.e. Tuple). This would be a primitive.
  2. apply the current Base.getfield to produce individual arguments.

The first step can be specialised, depending on the type. For example, if the type is already a Tuple or NamedTuple, then there's nothing that needs doing.

This approach has the limitation that the conversion to a Tuple might be hard to write a rule for, which would make life hard for Yota and my AD project, but we can always refactor in the future if we can think of an improvement.

edit: we could call the conversion function __convert_to_tuple_for_splatting__, or something similarly verbose which makes it clear that something interesting is going on.
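
A minimal sketch of the two-step proposal, under stated assumptions: the function name comes from this thread, but the method signatures and the `unsplat_sketch` helper are hypothetical and are not necessarily what ended up in Umlaut's `src/trace.jl`:

```julia
# Step 1: convert to something that supports getfield. A tracer would
# treat this as a primitive. Tuple and NamedTuple pass through unchanged.
__convert_to_tuple_for_splatting__(x::Tuple) = x
__convert_to_tuple_for_splatting__(x::NamedTuple) = x
__convert_to_tuple_for_splatting__(x) = tuple(x...)   # generic fallback: iterate once

# Step 2: extract each argument with Base.getfield.
# (Hypothetical helper; a real tracer would record these calls on the tape.)
function unsplat_sketch(x)
    t = __convert_to_tuple_for_splatting__(x)
    return Any[getfield(t, i) for i in 1:nfields(t)]
end

@assert unsplat_sketch((1, 2.0)) == [1, 2.0]
@assert unsplat_sketch((a = 1, b = 2)) == [1, 2]
@assert unsplat_sketch(zip(1:2, 3:4)) == [(1, 3), (2, 4)]
@assert +(unsplat_sketch([1, 2, 3])...) == +([1, 2, 3]...)
```

Keeping the conversion in one function means a downstream AD system only needs a rule for that single primitive, plus getfield.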

dfdx commented 11 months ago

Yes, __convert_to_tuple_for_splatting__ is one option. Maybe a bit more intuitive option is to introduce __getfield__ which gets elements directly from the splatted data type, but I don't have enough examples in my mind to evaluate corner cases. Perhaps we just need to try out something and see how it goes.

willtebbutt commented 11 months ago

I see your point -- the problem is that not every iterable naturally supports a getfield / getindex-like function. The interface that they support is iterate, which lets you get the next element. For example,

x = Iterators.takewhile(>(0), [0.1, 0.1, 0.1, -0.1, 0.1, 0.1, 0.1, 0.1])

You can definitely splat x:

julia> tuple(x...)
(0.1, 0.1, 0.1)

so it's within the scope of what we're interested in, but length(x) and getindex(x, 1) both throw MethodErrors, because you can't tell what length(x) is without iterating over x and seeing where it ends. So the reason I proposed doing this in two steps is to ensure that we iterate over a collection at most once, and that we only call our getfield-equivalent function on a data type for which that makes sense.
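
A short sketch of both points. The `is_method_error` helper is illustrative, and `Iterators.Stateful` is an extra example, not one from this thread, of an iterator that can only be consumed once:

```julia
# Illustrative helper: does f(args...) fail with a MethodError?
function is_method_error(f, args...)
    try
        f(args...)
        return false
    catch e
        return e isa MethodError
    end
end

x = Iterators.takewhile(>(0), [0.1, 0.1, 0.1, -0.1, 0.1, 0.1, 0.1, 0.1])
@assert is_method_error(length, x)
@assert is_method_error(getindex, x, 1)

# Converting once gives a Tuple, on which getfield is well defined.
t = tuple(x...)
@assert t == (0.1, 0.1, 0.1)
@assert getfield(t, 2) == 0.1

# Some iterators are consumed by iteration, so a second pass sees nothing.
s = Iterators.Stateful([1, 2, 3])
@assert tuple(s...) == (1, 2, 3)
@assert tuple(s...) == ()
```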

Does this seem reasonable, or am I missing something?

dfdx commented 11 months ago

Ah, right, I didn't take into account that we would need to iterate over the collection multiple times. So yes, your solution looks like the best option.

willtebbutt commented 11 months ago

Cool. I'll work on that today.

codecov-commenter commented 11 months ago

Codecov Report

Attention: 8 lines in your changes are missing coverage. Please review.

Comparison is base (a82adaf) 1.27% compared to head (7868498) 0.74%.

| Files | Patch % | Lines |
|---|---|---|
| src/trace.jl | 0.00% | 8 Missing :warning: |

Additional details and impacted files

```diff
@@           Coverage Diff            @@
##            main     #52     +/-   ##
========================================
- Coverage   1.27%   0.74%   -0.53%
========================================
  Files          8       7      -1
  Lines        707     670     -37
========================================
- Hits           9       5      -4
+ Misses       698     665     -33
```

willtebbutt commented 11 months ago

@dfdx I've implemented the proposal -- I think we're good to go if you're happy

dfdx commented 11 months ago

Perfect, thank you!