google-research / dex-lang

Research language for array processing in the Haskell/ML family
BSD 3-Clause "New" or "Revised" License

Strided index sets #309

Open oxinabox opened 3 years ago

oxinabox commented 3 years ago

It would be useful to have a stepped-range index set, to allow selecting, e.g., every even or every odd position.

Right now, to achieve that, I end up having to do:

for i. select (((ordinal i) `mod` 2) == 0) x.i 0
for i. select (((ordinal i) `mod` 2) == 1) x.i 0
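For comparison, here is a minimal Python sketch of how the masking workaround above differs from a true strided selection. Plain lists stand in for Dex tables, and `masked` and `strided` are hypothetical helper names. The mask keeps all n positions and fills the rest with 0. A strided index set yields only the selected positions.

```python
def masked(x, parity):
    # Mirrors the select-based workaround: output keeps len(x) entries,
    # with non-matching positions replaced by 0.
    return [v if i % 2 == parity else 0 for i, v in enumerate(x)]

def strided(x, step, offset):
    # What a strided index set provides: only every step-th element,
    # starting at position offset.
    return x[offset::step]

x = [10, 11, 12, 13, 14, 15]
print(masked(x, 0))      # [10, 0, 12, 0, 14, 0]
print(strided(x, 2, 0))  # [10, 12, 14]
print(strided(x, 2, 1))  # [11, 13, 15]
```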
apaszke commented 3 years ago

Great suggestion! I've changed the name to "strided", since I think this is the more commonly used name (at least in the NumPy/Python world).

duvenaud commented 2 years ago

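This is possible in user space now that #877 has landed. Here's one definition, where StridedIx n s offset is a wrapper type whose single constructor MkStridedIx holds an index of the underlying set n (the data declaration itself isn't shown):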

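instance {n s offset} [Ix n] Ix (StridedIx n s offset)
  -- Number of elements: ceil((size n - offset) / s)
  get_size            = \().
    a = ((size n) - offset)
    idiv a s + if (rem a s == 0) then 0 else 1
  -- Position of underlying index k within the strided set
  ordinal             = \(MkStridedIx k). idiv ((ordinal k) - offset) s
  -- The i-th element wraps underlying index offset + i * s
  unsafe_from_ordinal = \i. MkStridedIx (unsafe_from_ordinal n (offset + i * s))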
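The following is a rough Python model of the three instance methods above, treating indices of n as plain integers 0..n-1. It is not Dex code. It assumes s > 0 and 0 <= offset <= n, and the function names simply mirror the Dex ones. `get_size` computes ceil((n - offset) / s) the same way the instance does. `ordinal` and `unsafe_from_ordinal` should be inverses on the valid range.

```python
def get_size(n, s, offset):
    # ceil((n - offset) / s), written like the Dex instance:
    # integer division plus one extra slot when there is a remainder.
    a = n - offset
    return a // s + (0 if a % s == 0 else 1)

def ordinal(k, s, offset):
    # Position of underlying index k within the strided set.
    return (k - offset) // s

def unsafe_from_ordinal(i, s, offset):
    # Underlying index of the i-th element of the strided set.
    return offset + i * s

def elements(n, s, offset):
    # Enumerate the strided set as underlying indices, like `for i:(StridedIx ...)`.
    return [unsafe_from_ordinal(i, s, offset) for i in range(get_size(n, s, offset))]

print(elements(6, 2, 0))   # [0, 2, 4]  (Evens (Fin 6))
print(elements(6, 2, 1))   # [1, 3, 5]  (Odds (Fin 6))
print(elements(10, 2, 4))  # [4, 6, 8]
```

The `rem` check matters when (n - offset) is not a multiple of s. For example, with n = 7, s = 3, offset = 0 it yields three elements (0, 3, 6) rather than two.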

Now one can define, e.g., even and odd index sets:

def Evens (a:Type) [Ix a] : Type = StridedIx a 2 0
def Odds  (a:Type) [Ix a] : Type = StridedIx a 2 1

for i:(Evens (Fin 6)). i
[ (MkStridedIx (0@Fin 6))
, (MkStridedIx (2@Fin 6))
, (MkStridedIx (4@Fin 6)) ]@(StridedIx (Fin 6) 2 0)

for i:(Odds  (Fin 6)). i
[ (MkStridedIx (1@Fin 6))
, (MkStridedIx (3@Fin 6))
, (MkStridedIx (5@Fin 6)) ]@(StridedIx (Fin 6) 2 1)

Here's how to use this to loop over the relevant indices in a table:

full_table = for i:(Fin 10). ordinal i

for (MkStridedIx j):(StridedIx (Fin 10) 2 4).
  full_table.j
[4, 6, 8]@(StridedIx (Fin 10) 2 4)

However, there's a possible footgun. Below is what I wrote first, which typechecks but gives the wrong answer:

for i:(StridedIx (Fin 10) 2 4).
  full_table.((ordinal i)@_)
[0, 1, 2]@(StridedIx (Fin 10) 2 4)

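One problem is that the word 'ordinal' has two plausible meanings in this context: the position of i within the strided set (0, 1, 2) and the ordinal of the underlying Fin 10 index it wraps (4, 6, 8). Perhaps the general answer to this kind of problem is to train users to avoid casting in general.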
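To make the two meanings concrete, here is a small Python model of the footgun, with lists standing in for Dex tables. It is only an illustration, not how Dex evaluates the code. Indexing with the unwrapped underlying index gives the intended values. Reusing the strided ordinal as if it were a Fin 10 index silently reads the first three entries instead.

```python
full_table = list(range(10))  # models: full_table = for i:(Fin 10). ordinal i
s, offset = 2, 4
size = 3  # get_size for n = 10, s = 2, offset = 4: ceil((10 - 4) / 2)

# Correct: unwrap to the underlying Fin 10 index (offset + i * s) first.
right = [full_table[offset + i * s] for i in range(size)]

# Footgun: treat the strided ordinal i (0, 1, 2) as a Fin 10 index.
wrong = [full_table[i] for i in range(size)]

print(right)  # [4, 6, 8]
print(wrong)  # [0, 1, 2]
```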

Another crazy idea is to allow datatypes with no constructors that are simply wrappers for another datatype. Then we should be able to write something like:

for i:(StridedIx (Fin 10) 2 4).
  full_table.i

Here, i would have type n (Fin 10 in this example). I'm not sure this is even a coherent suggestion, though.

apaszke commented 2 years ago

Well, the ordinal solution does exactly what it says on the tin, and in some cases it might be what you want; just not here. In any case, we should teach our users not to develop the muscle memory of reaching for (ordinal i)@_ whenever they encounter an indexing-related type error, because that error is often a legitimate one!

About the final suggestion, I don't think having data types that work as you described is doable. You want them to behave exactly like the type they're derived from, except for the purpose of typeclass synthesis. But if you never even syntactically indicate where you expect a value to act like the derived type, then how is Dex supposed to know which typeclass instance to use?

One potential workaround would be to allow subtyping. Then StridedIx would still be a different type, but it would also be accepted wherever its base index type is. The same mechanism would let us avoid explicit injections for (..i). Unfortunately, it's not quite as simple to integrate into our type system... but we might try! Until then, we might also try adding an Injectable typeclass for indices:

interface Injectable sub super
  inj : sub -> super

instance {n} {i:n} Injectable (..i) n
  inj = \x. %inject x

instance {n s off} Injectable (StridedIx n s off) n
  inj = \(MkStridedIx x). x

Then, with your definitions, I get this:

full_table = for i:(Fin 10). ordinal i
for i:(StridedIx (Fin 10) 2 4). full_table.(inj i)
> [4, 6, 8]@(StridedIx (Fin 10) 2 4)
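As a loose analogy only, the sketch below mimics the Injectable idea in Python with `functools.singledispatch`, which picks an implementation from the runtime type of the first argument. The `StridedIx` dataclass and its `underlying` field are hypothetical stand-ins for the Dex wrapper. Unlike Dex typeclasses, the dispatch here happens at runtime rather than during type checking. A type with no registered injection raises `TypeError` instead of failing to typecheck.

```python
from dataclasses import dataclass
from functools import singledispatch

@dataclass(frozen=True)
class StridedIx:
    # Hypothetical stand-in for the Dex wrapper: holds the underlying index.
    underlying: int

@singledispatch
def inj(x):
    # Fallback when no injection is registered for type(x).
    raise TypeError(f"no injection for {type(x).__name__}")

@inj.register
def _(x: StridedIx) -> int:
    # Analogue of: instance {n s off} Injectable (StridedIx n s off) n
    return x.underlying

full_table = list(range(10))
strided = [StridedIx(4 + i * 2) for i in range(3)]  # StridedIx (Fin 10) 2 4
print([full_table[inj(i)] for i in strided])  # [4, 6, 8]
```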