Closed willtebbutt closed 11 months ago
I had a similar problem in Yota.jl when I needed the `getfield` primitive to be protected from the backprop mechanism. I solved it by introducing a special `_getfield()` function with the same behavior but different "primitiveness". Can we do the same thing here? E.g. introduce some internal function which is guaranteed to be a primitive and does what we need for all the types we care about?
I like this idea. How about the following: for each argument which is getting splatted, we proceed in two steps:

1. Call a conversion function which converts the argument to a type which supports `getfield` (i.e. `Tuple`). This would be a primitive.
2. Use `Base.getfield` to produce the individual arguments.

The first step can be specialised, depending on the type. For example, if the type is already a `Tuple` or `NamedTuple`, then there's nothing that needs doing.
This approach has the limitation that the conversion to a `Tuple` might be hard to write a rule for, which would make life hard for `Yota` and my AD project, but we can always refactor in the future if we can think of an improvement.
edit: we could call the conversion function `__convert_to_tuple_for_splatting__`, or something similarly verbose which makes it clear that something interesting is going on.
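The two-step proposal can be sketched as follows. This is only an illustration: the name `__convert_to_tuple_for_splatting__` is taken from the comment above, but the method bodies and the example iterator are assumptions, not the code that ended up in the package.

```julia
# Hypothetical step 1: a single conversion function that a tracer would
# treat as a primitive. Method bodies are assumptions for illustration.
__convert_to_tuple_for_splatting__(x::Tuple) = x            # nothing to do
__convert_to_tuple_for_splatting__(x::NamedTuple) = x       # nothing to do
# Generic fallback: iterate over the collection exactly once. The splat
# here happens inside the primitive, so the tracer never looks into it.
__convert_to_tuple_for_splatting__(x) = tuple(x...)

# An iterator with no getindex and no known length.
x = Iterators.takewhile(>(0), [0.1, 0.2, -0.1, 0.3])

# Step 1: one primitive call that produces a Tuple.
t = __convert_to_tuple_for_splatting__(x)

# Step 2: plain getfield calls recover the individual arguments.
a1 = getfield(t, 1)
a2 = getfield(t, 2)

# Equivalent to calling tuple(x...) directly.
result = tuple(a1, a2)
```

The point of the split is that only step 1 depends on the type being splatted; step 2 is the same `getfield` sequence for every input.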
Yes, `__convert_to_tuple_for_splatting__` is one option. Maybe a slightly more intuitive option is to introduce `__getfield__`, which gets elements directly from the splatted data type, but I don't have enough examples in mind to evaluate corner cases. Perhaps we just need to try something out and see how it goes.
I see your point -- the problem is that not every iterable naturally supports a `getfield` / `getindex`-like function. The interface that they support is `iterate`, which lets you get the next element. For example,
x = Iterators.takewhile(>(0), [0.1, 0.1, 0.1, -0.1, 0.1, 0.1, 0.1, 0.1])
You can definitely splat `x`:
julia> tuple(x...)
(0.1, 0.1, 0.1)
so it's within the scope of what we're interested in, but `length(x)` and `getindex(x, 1)` both yield `MethodError`s, because you can't tell what `length(x)` is without iterating over `x` and seeing when it ends. So the reason that I proposed to do this using two steps is to ensure that we iterate over a collection at most once, and that we only call our `getfield`-equivalent function on a data type for which doing so makes sense.
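A small check of the behaviour described above. The helper names `throws_method_error` and `collect_once` are hypothetical and exist only for this sketch; the second one walks the `iterate` protocol by hand to show the single pass that splatting relies on.

```julia
x = Iterators.takewhile(>(0), [0.1, 0.1, 0.1, -0.1, 0.1, 0.1, 0.1, 0.1])

# Hypothetical helper: report whether calling f(args...) raises a MethodError.
function throws_method_error(f, args...)
    try
        f(args...)
        return false
    catch err
        return err isa MethodError
    end
end

length_fails = throws_method_error(length, x)
getindex_fails = throws_method_error(getindex, x, 1)

# Hypothetical helper: materialise an iterator into a Tuple in one pass,
# using only the iterate interface.
function collect_once(itr)
    out = Any[]
    next = iterate(itr)
    while next !== nothing
        item, state = next
        push!(out, item)
        next = iterate(itr, state)
    end
    return Tuple(out)
end

t = collect_once(x)
```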
Does this seem reasonable, or am I missing something?
Ah, right, I didn't take into account that we would need to iterate over the collection multiple times. So yes, your solution looks like the best option.
Cool. I'll work on that today.
Attention: 8 lines in your changes are missing coverage. Please review.
Comparison is base (a82adaf) 1.27% compared to head (7868498) 0.74%.
Files | Patch % | Lines |
---|---|---|
src/trace.jl | 0.00% | 8 Missing :warning: |
@dfdx I've implemented the proposal -- I think we're good to go if you're happy
Perfect, thank you!
It turns out that while #48 didn't break any of the tests in Umlaut, it did break some examples of splatting in code that I'm testing on. This led me to dig around a little, and construct the additional test cases for Umlaut that I've added in this PR.
Notably:

- Not everything that can be splatted supports `getfield` -- this is most things other than `Tuple` and `NamedTuple`.
- `getindex` works on a greater array of things than `getfield`.
- Some iterables do not have `getindex` defined on them. I've created one of these by adding a test involving a `zip`, for which `getindex` doesn't work.

I'm not entirely sure what the right answer is here. I would really rather not revert to the `getindex` implementation of `unsplat!`, because it puts a (potentially) non-primitive on the tape, and doesn't handle all of the cases.

One option is to, in the general case, insert a `Core._apply_iterate(itr, tuple, x...)` onto the tape, which will output a `Tuple` -- we could then employ the current strategy on that `Tuple` (`getfield` calls etc). This is a little bit annoying, but it at least has the property that it should handle every case. Downstream consumers of the tapes produced by `Umlaut` could then handle whichever range of options they prefer. We could then add special handlers for `Tuple`s, `Vector`s, etc.

What are your thoughts @dfdx ?
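For reference, a minimal sketch of the `Core._apply_iterate` option on a `zip`. It assumes a recent Julia version in which `tuple(z...)` lowers to `Core._apply_iterate(Base.iterate, tuple, z)`; passing `Base.iterate` as the first argument is my reading of the `itr` shorthand above, and `Core._apply_iterate` is an internal builtin rather than public API.

```julia
z = zip([1, 2, 3], ["a", "b", "c"])

# getindex is not defined for a zip, so an index-based unsplat cannot work.
getindex_fails = try
    z[1]
    false
catch err
    err isa MethodError
end

# General case: materialise the iterable into a Tuple with one builtin call.
t = Core._apply_iterate(Base.iterate, tuple, z)

# The existing getfield-based strategy then applies to the resulting Tuple.
first_arg = getfield(t, 1)
second_arg = getfield(t, 2)
```

Because the result is always a `Tuple`, the downstream `getfield` calls do not need to know what kind of iterable was splatted.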