namedtensor / notation


Prototype Dex Implementation #17

Closed: srush closed this 3 years ago

srush commented 3 years ago

@apaszke convinced me that Dex can do named tensors even in its current form and provide type checking. It's pretty close. Here's a prototype that implements the current attention formulation using named tensors:

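def attention (q: {heads:h & seq2:s2 & key:kt} => Float)
              (k: {heads:h & seq:s & key:kt} => Float)
              (v: {heads:h & seq:s & val:vt} => Float) :
         ({heads:h & seq2:s2 & val:vt} => Float) =
     -- add the missing fields so q, k, and v all carry both seq and seq2
     q2 = ndim #seq q
     k2 = ndim #seq2 k
     v2 = ndim #seq2 v
     -- contract over key, then softmax along seq
     inner = (nfun #seq softmax) (ndot #key q2 k2)
     -- add val to the attention weights and contract over seq
     ndot #seq v2 (ndim #val inner)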
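
To make the shapes concrete, here is a plain-loop Python sketch of what the Dex attention function above computes. It is not Dex and not part of the prototype. Assumptions: each table is a Python dict keyed by a frozen record, sizes are passed explicitly (in Dex they come from the index types), softmax is the usual max-subtracted version, and the helper names index_set, freeze, and attention_loops are hypothetical. The comments map each loop to the corresponding ndot / nfun step.

```python
import math
from itertools import product


def index_set(sizes):
    """Yield every record (dict) of a record type such as {"heads": 2, "seq": 3}."""
    names = sorted(sizes)
    for values in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, values))


def freeze(record):
    """Turn a record into a hashable key that ignores field order."""
    return tuple(sorted(record.items()))


def attention_loops(q, k, v, sizes):
    """q: {heads, seq2, key}; k: {heads, seq, key}; v: {heads, seq, val}.

    Returns a table over {heads, seq2, val}.
    """
    out = {}
    for r in index_set({"heads": sizes["heads"], "seq2": sizes["seq2"]}):
        h = r["heads"]
        # ndot #key: contract q and k over "key", one score per seq position
        scores = [
            sum(q[freeze({**r, "key": c})] * k[freeze({"heads": h, "seq": s, "key": c})]
                for c in range(sizes["key"]))
            for s in range(sizes["seq"])
        ]
        # nfun #seq softmax: normalize the scores along "seq"
        m = max(scores)
        exps = [math.exp(x - m) for x in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # ndot #seq: weighted sum of v over "seq", one output per "val"
        for c in range(sizes["val"]):
            out[freeze({**r, "val": c})] = sum(
                weights[s] * v[freeze({"heads": h, "seq": s, "val": c})]
                for s in range(sizes["seq"])
            )
    return out


sizes = {"heads": 1, "seq2": 2, "seq": 3, "key": 2, "val": 1}
q = {freeze(r): 1.0 for r in index_set({n: sizes[n] for n in ("heads", "seq2", "key")})}
k = {freeze(r): 0.0 for r in index_set({n: sizes[n] for n in ("heads", "seq", "key")})}
v = {freeze(r): float(r["seq"]) for r in index_set({n: sizes[n] for n in ("heads", "seq", "val")})}
# All-zero keys give uniform weights, so each output is the mean of v over seq.
print(attention_loops(q, k, v, sizes))
```

The explicit loops play the role that ndim plays in the Dex version: every lookup builds a full record for the operand it reads from.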

The Dex formulation views names as record index types. Dex automatically generates functions of the form #seq that act as lenses for accessing the corresponding record fields.
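
A minimal Python sketch of the lens idea, assuming a record index can be modeled as a dict from field name to position. The three helpers loosely mirror what getAt, popAt, and pushAt do for one named field; they are hypothetical illustrations, not Dex's actual API.

```python
def get_at(name, record):
    """Read the position stored in one field (cf. getAt)."""
    return record[name]


def pop_at(name, record):
    """Drop one field, returning the rest of the record (cf. popAt)."""
    return {k: v for k, v in record.items() if k != name}


def push_at(name, value, rest):
    """Re-insert a field into the remaining record (cf. pushAt)."""
    if name in rest:
        raise KeyError(f"field {name!r} is already present")
    return {**rest, name: value}


record = {"heads": 1, "seq": 2}
print(get_at("seq", record))  # 2
print(pop_at("seq", record))  # {'heads': 1}
# Lens round trip: pushing back what was popped recovers the original record.
print(push_at("seq", get_at("seq", record), pop_at("seq", record)) == record)  # True
```

Splitting a record into "the focused field" and "everything else" is what lets every helper in the prototype below be written once, generically over the field name.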

Dex's record syntax also lets you write things in roughly the same notation we have been using. For example, to sum out heads:

def indexsum (q: {heads:h & seq2:s2 & key: kt} => Float) :                           
         ({seq2:s2 & key: kt} => Float) =                                            
    sum for i:h. q.{heads=i}

or alternatively

def indexsum (q: {heads:h & seq2:s2 & key: kt} => Float) : 
         ({seq2:s2 & key: kt} => Float) =
    (nred #heads sum)  q
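
Outside Dex, the same reduction can be sketched in Python with tables stored as dicts keyed by frozen records. The name nred is reused only for readability; it is a hypothetical helper that applies fn to every 1-D slice along one name and drops that name from the index.

```python
from itertools import product


def index_set(sizes):
    """Yield every record (dict) of a record type such as {"heads": 2}."""
    names = sorted(sizes)
    for values in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, values))


def freeze(record):
    """Hashable, order-independent key for a record."""
    return tuple(sorted(record.items()))


def nred(name, fn, table, sizes):
    """Reduce `table` along `name` with fn (vector -> scalar)."""
    rest_sizes = {n: s for n, s in sizes.items() if n != name}
    out = {}
    for rest in index_set(rest_sizes):
        vector = [table[freeze({**rest, name: i})] for i in range(sizes[name])]
        out[freeze(rest)] = fn(vector)
    return out


sizes = {"heads": 2, "seq2": 1, "key": 2}
q = {freeze(r): 10.0 * r["heads"] + r["key"] for r in index_set(sizes)}
summed = nred("heads", sum, q, sizes)  # indexed by {seq2, key}
print(summed[freeze({"seq2": 0, "key": 1})])  # 12.0, i.e. (0 + 1) + (10 + 1)
```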

The only thing I am stuck on (maybe @apaszke knows the answer?) is whether this can do broadcasting. My current implementation manually expands extra dimensions with ndim so that the record types line up. Is there a nice way to get the union of two records automatically? (In particular, is there a version of ndot below where the two a's can be different types?)
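
One way to picture the broadcasting being asked about, sketched in Python rather than Dex: take the union of the two operands' names, check that shared names agree in size, and index each operand by restricting the union record to its own fields. This only illustrates record-union broadcasting under those assumptions and says nothing about whether Dex's type system can express it; union_sizes and broadcast_mul are hypothetical names.

```python
from itertools import product


def index_set(sizes):
    names = sorted(sizes)
    for values in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, values))


def freeze(record):
    return tuple(sorted(record.items()))


def union_sizes(sizes1, sizes2):
    """Union of two record types; shared names must have the same size."""
    for name in sizes1.keys() & sizes2.keys():
        if sizes1[name] != sizes2[name]:
            raise ValueError(f"size mismatch on {name!r}: {sizes1[name]} vs {sizes2[name]}")
    return {**sizes1, **sizes2}


def broadcast_mul(t1, sizes1, t2, sizes2):
    """Elementwise product over the union of names, with no manual ndim."""
    sizes = union_sizes(sizes1, sizes2)
    out = {}
    for r in index_set(sizes):
        r1 = {n: r[n] for n in sizes1}  # restrict the union record to t1's fields
        r2 = {n: r[n] for n in sizes2}  # ... and to t2's fields
        out[freeze(r)] = t1[freeze(r1)] * t2[freeze(r2)]
    return out, sizes


a = {freeze({"seq": i}): float(i + 1) for i in range(2)}         # seq: [1, 2]
b = {freeze({"key": j}): float(10 * (j + 1)) for j in range(3)}  # key: [10, 20, 30]
outer, sizes = broadcast_mul(a, {"seq": 2}, b, {"key": 3})
print(sizes)                               # {'seq': 2, 'key': 3}
print(outer[freeze({"seq": 1, "key": 2})])  # 60.0
```

Disjoint names give an outer product and shared names are matched position by position, which is the usual broadcasting behavior for named dimensions.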

Here's my full implementation if you are interested. It is very similar to @davidweichiang's named and numbered style: pop pulls out a vector dimension and push puts it back.

-- rename: relabel the field focused by name1 (in a) as the field focused by name2 (in d)
def rename (name1: Iso a (b & c)) (name2: Iso d (b & c))
    (tensor: a => Float) : (d => Float) =

    for i: d.
        value = getAt name2 i
        old = popAt name2 i
        new = pushAt name1 value old
        tensor.new

-- ndot: multiply elementwise, then sum out the field focused by name
def ndot (name: Iso a (b & c))
         (tensor1: a => Float)
         (tensor2: a => Float)
  : c => Float =
  for i : c.
     sum for j:b.
        newindex = pushAt name j i
        tensor1.newindex * tensor2.newindex

-- push: fold an outer b dimension back into the record index
def push (name: Iso a (b & c))
         (tensor1: b => c => Float)
  : a => Float =
  for i : a.
      index1 = getAt name i
      index2 = popAt name i
      tensor1.index1.index2

-- pop: pull the field focused by name out as an outer dimension
def pop (name: Iso a (b & c))
        (tensor1: a => Float)
  : b => c => Float =
  for i : b.
    for j : c.
      index = pushAt name i j
      tensor1.index

-- nred: reduce every 1-D slice along name with fn, dropping that field
def nred (name: Iso a (b & c))
         (fn : b => Float -> Float) :
         (a => Float -> c => Float) =
   \tensor.
     t2 = pop name tensor
     for j: c. fn (transpose t2).j

-- nfun: apply a vector-to-vector fn along name, keeping the index type
def nfun (name: Iso a (b & c))
         (fn : b => Float -> b => Float) :
         (a => Float -> a => Float) =
   \tensor.
     t2 = pop name tensor
     push name (transpose (for j: c. fn (transpose t2).j))

-- ndim: broadcast a tensor along a new field (its size comes from the type b)
def ndim (name: Iso a (b & c))
         (tensor: c => Float) : (a => Float) =
    push name (for i: b. tensor)
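
For readers without Dex at hand, here is a rough Python analogue of pop, push, ndim, ndot, and rename. Assumptions: tables are dicts keyed by frozen records, and a new field's size is passed explicitly (in Dex it comes from the index type). It is a hypothetical sketch meant only to show that push undoes pop, that ndim broadcasts by ignoring the new field, and that rename just relabels a field.

```python
from itertools import product


def index_set(sizes):
    names = sorted(sizes)
    for values in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, values))


def freeze(record):
    return tuple(sorted(record.items()))


def pop(name, table, sizes):
    """Pull `name` out as an outer list dimension: list over name of {rest: value}."""
    rest_sizes = {n: s for n, s in sizes.items() if n != name}
    return [
        {freeze(rest): table[freeze({**rest, name: i})] for rest in index_set(rest_sizes)}
        for i in range(sizes[name])
    ]


def push(name, nested):
    """Inverse of pop: fold the outer list dimension back into the record index."""
    return {
        freeze({**dict(rest), name: i}): value
        for i, inner in enumerate(nested)
        for rest, value in inner.items()
    }


def ndim(name, size, table):
    """Broadcast along a new field: every position of `name` sees the same table."""
    return push(name, [dict(table) for _ in range(size)])


def ndot(name, t1, t2, sizes):
    """Multiply elementwise and sum out `name` (both tables share `sizes`)."""
    rest_sizes = {n: s for n, s in sizes.items() if n != name}
    return {
        freeze(rest): sum(
            t1[freeze({**rest, name: i})] * t2[freeze({**rest, name: i})]
            for i in range(sizes[name])
        )
        for rest in index_set(rest_sizes)
    }


def rename(old, new, table):
    """Relabel field `old` as `new` in every index."""
    out = {}
    for key, value in table.items():
        record = dict(key)
        record[new] = record.pop(old)
        out[freeze(record)] = value
    return out


sizes = {"seq": 2, "key": 3}
t = {freeze(r): float(3 * r["seq"] + r["key"]) for r in index_set(sizes)}
print(push("seq", pop("seq", t, sizes)) == t)  # True: push undoes pop
print(len(ndim("val", 4, t)))                  # 24 entries over {seq, key, val}
print(ndot("key", t, t, sizes))                # {(('seq', 0),): 5.0, (('seq', 1),): 50.0}
```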

boazbk commented 3 years ago

Nice! Will take a look (I do need to brush up on my nonexistent Haskell :) )

I was also trying to think about this in terms of code, and wrote up my initial thoughts at https://hackmd.io/@boazbk/HyUg4D9iw

srush commented 3 years ago

Neat, I'll take a look.

The Dex style is interesting. They really do treat indexing entirely through record types, similar to the v1 proposal. There is no named dimension type (DID), just a mapping from a name to a standard finite dimension type.

So the default would be:

for w in range(W): 
    for h in range(H):
        print(A[{width:w, height:h}])

Alternatively, you can do:

for index in indexset({width: W, height:H}):
    print(A[index])

But as far as I can tell, there is nothing in between; i.e., this would not work without more explicit transformations:

for index in indexset({width: W}):
    print(A[index])

Although if that's the style we arrive at, I'm sure we could make it work.
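
The contrast above can be made concrete with a small, runnable Python stand-in (not Dex): a table keyed by full records accepts only complete indices, so a partial record such as {width} fails until the remaining field is iterated over explicitly. W, H, and the dict-of-records representation are assumptions chosen for illustration.

```python
from itertools import product

W, H = 3, 2


def freeze(record):
    return tuple(sorted(record.items()))


def index_set(sizes):
    names = sorted(sizes)
    for values in product(*(range(sizes[n]) for n in names)):
        yield dict(zip(names, values))


# A over the record type {width: W, height: H}
A = {freeze({"width": w, "height": h}): 10 * w + h for w in range(W) for h in range(H)}

# Full records work, whether built by nested loops or enumerated from the index set.
full = [A[freeze(index)] for index in index_set({"width": W, "height": H})]
print(len(full))  # 6

# A partial record is not a valid index ...
try:
    A[freeze({"width": 0})]
except KeyError:
    print("partial index rejected")

# ... unless the missing field is supplied by an explicit inner loop.
rows = {w: [A[freeze({"width": w, "height": h})] for h in range(H)] for w in range(W)}
print(rows[2])  # [20, 21]
```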

srush commented 3 years ago

I am going to close this as I think the type system of Dex is different enough from what we are building that it would be hard to bridge the gap. Dex is neat, Named Tensors is neat, but they are different beasts.

oxinabox commented 3 years ago

Have you seen the idea of Existential Dimensions? I got this idea secondhand from @dougalm via @jekbradbury, and I don't know its current status in terms of being able to do it in Dex (see https://github.com/invenia/NamedDims.jl/issues/61), but I think it would be excellent to be able to do.

It solves the problem that various operations that should return a named dimension don't know what that name should be. For example, multiplying an unnamed tensor by a named tensor gives one dimension with an unknown name; let's call that an existential name. Another example is the latent dimension from a matrix factorization, which gives you two existential names, on different arrays, that must be equal to each other. But if you add two tensors, one fully named (call it publicly named) and one with an existential name, then we know that the existential name must be equal to that public name. So you can do a kind of type inference to propagate that name to every other existential name that has to be the same as this one. And if, while doing this, you end up trying to assign two different public names to the same existential name, you throw an error, because someone has done something invalid.
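
The inference described above behaves like unification, so a union-find structure is one natural way to sketch it. Below is a minimal Python sketch under these assumptions: existential names are fresh variables spelled with a leading "?", every other string is a public name, and two different public names in one class raise an error. The class NameSolver is hypothetical and is not taken from Dex or NamedDims.jl.

```python
class NameSolver:
    """Union-find over dimension names.

    Names starting with "?" are existential (unknown); any other string is a
    public name. Each equivalence class remembers at most one public name.
    """

    def __init__(self):
        self.parent = {}
        self.public = {}  # root -> public name of its class, or None
        self.count = 0

    def _add(self, name):
        if name not in self.parent:
            self.parent[name] = name
            self.public[name] = None if name.startswith("?") else name

    def fresh(self):
        """Create a new existential name, e.g. for a latent dimension."""
        name = f"?{self.count}"
        self.count += 1
        self._add(name)
        return name

    def find(self, name):
        self._add(name)
        while self.parent[name] != name:
            self.parent[name] = self.parent[self.parent[name]]  # path halving
            name = self.parent[name]
        return name

    def unify(self, a, b):
        """Record that dimensions a and b must be the same dimension."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        pa, pb = self.public[ra], self.public[rb]
        if pa is not None and pb is not None and pa != pb:
            raise TypeError(f"cannot unify public names {pa!r} and {pb!r}")
        self.parent[rb] = ra
        self.public[ra] = pa if pa is not None else pb

    def resolve(self, name):
        """Public name inferred for `name`, or None if still existential."""
        return self.public[self.find(name)]


# Matrix factorization: the latent dims on the two factors must match.
solver = NameSolver()
left_latent, right_latent = solver.fresh(), solver.fresh()
solver.unify(left_latent, right_latent)
# Adding a publicly named tensor pins the name down for both factors.
solver.unify(left_latent, "rank")
print(solver.resolve(right_latent))  # rank
# A conflicting public name is an error.
try:
    solver.unify(right_latent, "batch")
except TypeError as err:
    print(err)  # cannot unify public names 'rank' and 'batch'
```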

There is also a fun extension that does this with namespaces, so you can have one public name per namespace. I think that, if done right, this would let you deal with the fact that one library might call observations :obs and another might call them :times.
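
One possible reading of the namespace extension, sketched in Python as a hypothetical variant of union-find unification: public names carry a namespace (spelled "namespace:name", an assumption), and an equivalence class may hold at most one public name per namespace. Library A's obs can then be identified with library B's times, while two different names from the same library still conflict. NamespacedSolver is an illustrative name only.

```python
class NamespacedSolver:
    """Union-find over dimension names; each class keeps one public name per namespace."""

    def __init__(self):
        self.parent = {}
        self.public = {}  # root -> {namespace: public name}
        self.count = 0

    def _add(self, name):
        if name not in self.parent:
            self.parent[name] = name
            if name.startswith("?"):  # existential: no public names yet
                self.public[name] = {}
            else:  # public names are assumed to be spelled "namespace:name"
                namespace = name.partition(":")[0]
                self.public[name] = {namespace: name}

    def fresh(self):
        name = f"?{self.count}"
        self.count += 1
        self._add(name)
        return name

    def find(self, name):
        self._add(name)
        while self.parent[name] != name:
            self.parent[name] = self.parent[self.parent[name]]  # path halving
            name = self.parent[name]
        return name

    def unify(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        merged = dict(self.public[ra])
        for namespace, name in self.public[rb].items():
            if merged.get(namespace, name) != name:
                raise TypeError(
                    f"{merged[namespace]!r} and {name!r} clash in namespace {namespace!r}"
                )
            merged[namespace] = name
        self.parent[rb] = ra
        self.public[ra] = merged

    def resolve(self, name, namespace):
        """Public name for `name` as seen from `namespace`, or None if unknown."""
        return self.public[self.find(name)].get(namespace)


solver = NamespacedSolver()
obs = solver.fresh()
solver.unify(obs, "libA:obs")    # library A calls this dimension obs
solver.unify(obs, "libB:times")  # library B calls the same dimension times
print(solver.resolve(obs, "libB"))  # libB:times
try:
    solver.unify(obs, "libA:batch")  # a second name for it inside libA
except TypeError as err:
    print(err)
```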