crutcher / tapestry

Experimental tensor application compiler suite.
https://github.com/crutcher/tapestry/tree/main/docs

Metakernel Template Language Design #2

Open crutcher opened 8 months ago

crutcher commented 8 months ago

See as background: PolyhedralTypesAndIndexProjection

A metakernel is the symbolic execution pair of a kernel. Given inputs and parameters, it produces a graph description of the operation: the shapes and types of the inputs and outputs, the sharding index spaces and index projections appropriate for the operation, and the constraints and requirements of the operation.

In the following example, AddKernel is a metakernel.

import java.util.List;

import org.tensortapestry.loom.graph.CommonEnvironments;
import org.tensortapestry.loom.graph.dialects.tensorops.TensorNode;
// TensorSelection is assumed to live alongside TensorNode:
import org.tensortapestry.loom.graph.dialects.tensorops.TensorSelection;
import org.tensortapestry.weft.dialects.tensorops.AddKernel;

static class Example {

  public static void main(String[] args) {
    var env = CommonEnvironments.expressionEnvironment();
    var graph = env.newGraph();

    var tensorA = TensorNode
      .builder(graph)
      .label("A")
      .configure(c -> c.dtype("int32").shape(10, 10))
      .build();

    var tensorB = TensorNode
      .builder(graph)
      .label("B")
      .configure(c -> c.dtype("int32").shape(10, 10))
      .build();

    var add = new AddKernel();

    var op = add
      .on(graph)
      .input("tensors", List.of(TensorSelection.from(tensorA), TensorSelection.from(tensorB)))
      .apply();

    graph.validate();
  }
}

Which would produce a graph fragment like the following (rendered as the "add example" graph image):

This example metakernel is implemented in Java; but it would be useful and powerful to be able to describe metakernels in a portable template language, as they'll frequently be provided by external client libraries.

Consider the following draft template for matmul:

matmul:
  index: "[$batch..., $in_row, $out_col]"

  constraints:
    dtype:
      enum:
        - int32
        - int64
        - float32
        - float64
        - complex64
        - complex128

  inputs:
    X:
      shape: "[$batch..., $in_row, $inner]"
      dtype: "$dtype"

      # this is shorthand for:
      # ipf:
      #   map: "[..., 1, 0]"
      #   shape: "[..., 1, $inner]"
      #
      # which is shorthand for:
      # ipf:
      #   map: "[ones($index.size - 2)..., 1, 0]"
      #   shape: "[ones($index.size - 2)..., 1, $inner]"
      #
      # Where the map unpacks to a diagonal matrix
      # with 1s on the prefix dimensions.
      ipf: "[..., 1, 0] :> [..., 1, $inner]"

    W:
      shape: "[$batch..., $inner, $out_col]"
      dtype: "$dtype"
      ipf: "[..., 0, 1] :> [..., $inner, 1]"

  outputs:
    result:
      # shape defaults to: "[$.index...]"
      # ipf defaults to: "[...] >: [...]"
      dtype: "$dtype"
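To make the shorthand expansion above concrete, here is a minimal hypothetical sketch (not the weft implementation; class and method names are invented) of how the X-input ipf shorthand "[..., 1, 0] :> [..., 1, $inner]" unpacks, assuming the "..." fills the batch prefix with 1s up to the index rank:

```java
import java.util.Arrays;

// Hypothetical sketch of expanding the "[..., tail] " shorthand for the
// matmul template's X input; NOT the weft implementation.
public class IpfShorthand {

  // Expand "[..., a, b]" to a full diagonal map of length `indexRank`,
  // filling the leading "..." positions with 1s.
  public static int[] expandMap(int indexRank, int... tail) {
    int[] map = new int[indexRank];
    Arrays.fill(map, 0, indexRank - tail.length, 1); // 1s on the batch prefix
    for (int i = 0; i < tail.length; i++) {
      map[indexRank - tail.length + i] = tail[i];
    }
    return map;
  }

  // Expand "[..., 1, $inner]" for X's attached range shape:
  // unit everywhere except the trailing $inner dimension.
  public static int[] expandShape(int indexRank, int inner) {
    int[] shape = new int[indexRank];
    Arrays.fill(shape, 1);
    shape[indexRank - 1] = inner;
    return shape;
  }
}
```

For an index "[$batch, $in_row, $out_col]" with one batch dimension (rank 3), X's map expands to [1, 1, 0] and its shape to [1, 1, inner], matching the longhand "[ones($index.size - 2)..., 1, 0]" form in the template comments.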

The current weft metakernel language is an incomplete draft, and needs to be fleshed out to cover how operations such as concat(ts, axis) and split(t, chunk_size, axis) are to be represented.

Alternatively, metakernels could be small javascript programs, and we could require a host environment to interpret them. This is the fallback design, and I want to put it off for as long as possible.

A good metakernel template language is an excellent base for a good graph rewrite language; and the goal is to be able to develop a graph rewrite template language as an extension.

crutcher commented 8 months ago

The core problem with axis/dim parameterized metakernels is that their index spaces and projections vary structurally with that parameter.

There are 3 common ways to handle this:

The broadcast dimensions complicate slice manipulation operations; and we could have a top-level "broadcast to batch" transform applied.

It might also be possible to say things like

  ipf: "[...[:$axis - 1], <axis-code>, ...[$axis+1:]] :> [...[:$axis-1], <axis-code>, ...[$axis+1:]]"
gabrielgrant commented 8 months ago

before getting into the core of what you're actually talking about here, going to nitpick a bit to ensure I'm understanding correctly

1.

A metakernel is the symbolic execution pair of a kernel

not sure I'm correctly understanding your use of the word "pair" -- i think what you mean by this is "a metakernel holds the information (or metadata) about a kernel needed for symbolic computation" ? is that correct? (and "pair" means something analogous to "twin" in "digital twin"?)

how would you differentiate this from kernel typing information? i think a metakernel describes a superset of what's usually considered type info? the polyhedral type info/model?

2.

it produces a graph description of the operation

can you clarify how you're using "kernel" vs "operation" (and "op", which i presume is synonymous?). i think of general usage in tensor frameworks being that an op is the symbolic representation, and a kernel is a specific, concrete implementation of that op (usually specific to a given hardware/execution backend)

3.

not sure if you have a term for a specific, concrete sharding of a kernel/op? (i don't know of a term for this from elsewhere)

4.

to maintain consistency with the terminology you've used later, i think this could be "it produces a graph fragment describing the operation"? or is the graph you're talking about here different from the graph fragment you mention later?

5.

perhaps a silly question, but what does "ipf" stand for?

6.

I'm having trouble understanding the problem you're describing in the comment. perhaps an example would help me get it into my thick skull :P

crutcher commented 8 months ago

A metakernel is to a kernel as symbolic execution is to execution.

In a programming language without spatial structure in the types, symbolic execution mainly needs to keep track of a call graph. But in loom, we care about the shape of the result of an operation; so a metakernel not only validates that the input names and types are legal, but also generates the resulting operation output shapes and projections.

We describe them as pairs, because they are equivalent in their respective domains.

We can argue that a given kernel, from tapestry's perspective, is a strongly typed contract for how some code should work, and not actually the implementation of that code; such that we could have implementations of a kernel for different targets. But it isn't really necessary to make that distinction, because as far as tapestry is concerned, a kernel is just the label on some external and opaque code.

So a kernel is the opaque identifier for an external implementation of an algorithm; and a metakernel is a program which will take symbolic structural descriptions of inputs to a kernel, validate their structure, and produce symbolic structural descriptions of outputs from that kernel, without calling it.

An operation (as opposed to an operator) is used in loom to specify a specific act of applying a transform in the graph. Many operations may share the same kernel; and an operation may be sharded into many application shards.

operation was picked in opposition to operator; we're describing the specific instance, not the generic function.

"ipf" is shorthand for "IndexProjectionFunction", it is the projection from the index space of the operation to the index space of the tensor; a core idea in polyhedral type theory.

See: https://crutcher.github.io/Tapestry/#Index-Projection-Functions See: https://github.com/crutcher/loom/blob/main/tensortapestry-loom/README.md#polyhedral-type-theory-sharding
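As an illustration of what an index projection function does, a minimal hypothetical sketch (invented names, not the loom API): a point in the operation's index space is mapped affinely to the start corner of a range in the tensor's coordinate space, and the attached shape gives the extent of that range.

```java
// Hypothetical sketch of applying an IndexProjectionFunction's affine map;
// NOT the loom API. start = projection * point + offset.
public class IpfApply {

  public static int[] projectStart(int[][] projection, int[] offset, int[] point) {
    int[] start = new int[projection.length];
    for (int row = 0; row < projection.length; row++) {
      int acc = offset[row];
      for (int col = 0; col < point.length; col++) {
        acc += projection[row][col] * point[col];
      }
      start[row] = acc;
    }
    return start;
  }
}
```

With an identity projection, zero offset, and unit shape, an index point maps to exactly one tensor cell; a projection with a zero row (as in a reduction) pins that tensor dimension's start to the offset, and the attached shape sweeps the reduced extent.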

The issue is that the structure of the projection function differs for different values of $axis for operations such as sum(tensor, axis) and concat(tensors, axis); and the template language needs a way to express that variation.

crutcher commented 8 months ago

Consider the case where we wish to sum over a dimension; and ignore multi-stage hierarchical reduction kernels for now, so everything happens in one kernel.

Suppose we have sum(Tensor tensor, int axis) which sums over the given axis.

For input shape [a, b, c, d]:

Now, keep in mind that our metakernel expressions for inputs are patterns: we use them to match variable names against the input shapes we're provided. The expressions for the structure of the index space, the shape of the output tensors, and all index projection functions are expressions, computed in terms of other values, not match patterns.

Generally, we'll find that for simple operations the operation index will match the shape of the output tensor, and the projection from the index to the output can be the identity.

This identity projection can be expressed as "D[1...]" for the output ipf.

There are several layers to the unpack of this expression.

  1. ... expands its value to fill the list it is in up to the context size, so this says "D[1, 1, 1, ...]" up to some target.
  2. The context size for a tensor is the size of the coordinate space of that tensor; so in this case 3 => "D[1, 1, 1]"
  3. A diagonalized representation of an IPF is assumed to have zero offset, and an attached unit shape, unless otherwise specified => "D[1, 1, 1] + [0, 0, 0] :> [1, 1, 1]"
  4. A diagonalized IPF is transformed into the associated matrix representation:
    {
      "affineMap": {
        "projection": [
          [1, 0, 0],
          [0, 1, 0],
          [0, 0, 1]
        ],
        "offset": [0, 0, 0]
      },
      "shape": [1, 1, 1]
    }
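The four unpacking steps above can be sketched as code. This is a hypothetical illustration (invented names, not the loom implementation) of expanding "D[1...]" for a rank-3 context into the matrix representation:

```java
import java.util.Arrays;

// Hypothetical sketch of unpacking the diagonal shorthand "D[1...]";
// NOT the loom implementation.
public class DiagonalUnpack {

  // Step 1-2: "..." expands its value up to the context size.
  public static int[] ellipsisFill(int value, int contextSize) {
    int[] out = new int[contextSize];
    Arrays.fill(out, value);
    return out;
  }

  // Step 4: a diagonalized representation expands to a projection matrix
  // with the given values on the diagonal.
  public static int[][] expandDiagonal(int[] diag) {
    int n = diag.length;
    int[][] projection = new int[n][n];
    for (int i = 0; i < n; i++) {
      projection[i][i] = diag[i];
    }
    return projection;
  }
}
```

Here ellipsisFill(1, 3) yields [1, 1, 1], and expandDiagonal turns that into the 3x3 identity projection shown in the JSON above; steps 3's zero offset and unit shape are likewise ellipsisFill(0, 3) and ellipsisFill(1, 3).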

The construction of the index space, and the projection for the input, are less straightforward.

Suppose we gave the input tensor the pattern "[$shape...]"; saying we wished to bind the variable $shape to the entire input shape.

To construct an index space, which does one unit of work per $axis axis, we need some way to slice that shape; preferably using variations of the same shorthand from above.

Consider an index space range of: "[$shape[:$axis - 1]..., $shape[$axis + 1:]...]"

This expression has two invocations of ..., but we can disambiguate them: each ... expands the slice expression it is directly attached to.

So, for axis=2:

  1. "[a, b, c, d]" would match to $shape = [a, b, c, d];
  2. "[[a, b, c, d][:2 - 1] ..., [a, b, c, d][2 + 1:]...]";
  3. "[[a, b]..., [d]...]"
  4. "[a, b, d]"

To construct the ipf for the input tensor (where we're trying to map index space [a, b, d] to input shape [a, b, c, d], and perform axis operations on axis=2/c), we could say:

"[1...[:$axis]..., 0, 1...[$axis + 1:]...] :> [1...[:$axis]..., $shape[$axis], 1...[$axis + 1:]...]"

This would require us to interpret a scalar expansion in the context of a slice applied to it for length; but if we permit that, we could unpack the above:

  1. "[1...[:$axis]..., 0, 1...[$axis + 1:]...] :> [1...[:$axis]..., $shape[$axis], 1...[$axis + 1:]...]"
  2. "[1...[:2]..., 0, 1...[2 + 1:]...] :> [1...[:2]..., [a, b, c, d][2], 1...[2 + 1:]...]"
  3. "[1...[:2]..., 0, 1...[3:]...] :> [1...[:2]..., c, 1...[3:]...]"
  4. "[[1, 1]..., 0, [1]...] :> [[1, 1]..., c, [1]...]"
  5. "[1, 1, 0, 1] :> [1, 1, c, 1]"

Here, we run into a problem. The index space has 3 dims; the input tensor has 4.

We could distort the index space with a 1 dimension, at the cost of making the output IPF weirder. We'd like a rule to cleanly say that we're adding a dimension, but still use the diagonal form; because we'd like a way to go from the short form to this:

{
  "affineMap": {
    "projection": [
      [1, 0, 0],
      [0, 1, 0],
      [0, 0, 0],
      [0, 0, 1]
    ],
    "offset": [0, 0, 0, 0]
  },
  "shape": [1, 1, "c", 1]
}
crutcher commented 8 months ago

A grammar like this brings us close to the above:

grammar F;

prog
   : expr EOF
   ;

expr
   : <assoc = right> lhs = expr op = POW rhs = expr # BinOpExpr
   | lhs = expr op = (TIMES | DIV | MOD) rhs = expr # BinOpExpr
   | lhs = expr op = (PLUS | MINUS) rhs = expr # BinOpExpr
   | op = MINUS e = expr # NegateExpr
   | expr ELLIPSIS # EllipsisExpr
   | expr select # SelectExpr
   | LBRACKET expr (COMMA expr)* RBRACKET # ListExpr
   | LPAREN e = expr RPAREN # ParensExpr
   | atom # AtomExpr
   ;

select
   : LBRACKET expr RBRACKET # IndexSelect
   | LBRACKET expr COLON RBRACKET # SliceToSelect
   | LBRACKET COLON expr RBRACKET # SliceFromSelect
   | LBRACKET expr COLON expr RBRACKET # SliceSelect
   ;

atom
   : val=integer # NumberAtom
   | id=variable # VariableAtom
   ;

integer
    : INTEGER_LITERAL
    ;

variable
   : DOLLAR id=qual_id
   ;

qual_id
   : ID (DOT ID)*
   ;
crutcher commented 8 months ago

Alternatively, we could say that there are shape and ipf transforms.

Something like, "where you could put a shape or ipf, you can instead put a transform function";

so, say,

ipf:
  transform: insert
  index: "$axis"
  source: "[1...]"
  patch: "[0] :> [$shape[$axis]]"

Which is a more focused problem; but still needs a way to decide how to build the projection matrix for "[0]", or some variation.

crutcher commented 8 months ago

An affine projection, such as this:

{
  "projection": [
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 2]
  ],
  "offset": [3, 4, 5]
}

Is sometimes written as an augmented matrix:

\left( 
    \begin{array}{ccc|c}
        1 & 0 & 0 & 3 \\
        0 & 1 & 0 & 4 \\
        0 & 1 & 2 & 5
    \end{array}
\right)

We could squint and see our index range projection functions as doubly augmented matrices, such that this:

{
  "affineMap": {
    "projection": [
      [1, 0, 0],
      [0, 1, 0],
      [0, 1, 2]
    ],
    "offset": [3, 4, 5]
  },
  "shape": [1, 1, 5]
}

Is equivalent to this:

\left( 
    \begin{array}{ccc|c|c}
        1 & 0 & 0 & 3 & 1 \\
        0 & 1 & 0 & 4 & 1 \\
        0 & 1 & 2 & 5 & 5
    \end{array}
\right)

And it might be easier to think in terms of inserting or replacing rows on this structure.

ipf:
  transform: insertRow
  source: "[1...]"
  index: "$axis"
  row: "[0... | 0 | $shape[$axis] ]"
crutcher commented 7 months ago

I've been iterating on an expanded pattern language

Suppose you have a list of tensors, with shapes:

"tensors": [
  [100, 128, 256, 512, 2],
  [100, 128, 256, 512, 4]
],
"masks": [
  [256, 512, 2],
  [256, 512, 4]
]

And you want to parse them out into dimension shapes, as a structure. Suppose we had machinery to apply shape patterns:

"tensors": "[$batch..., $shape=($height, $width), $features[$i]]",
"masks": "[$shape=($height, $width), $features[$i]]"

and applying it to the above would yield:

$batch=[100, 128]
$shape=[256, 512]
$height=256
$width=512
$i=2
$features=[2, 4]

But only if there is a consistent constant value for each named component of the pattern. I have a previous version without support for indexed dimension values; and I've implemented the expression language parser for this, but not yet the dim group matcher.
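The consistency requirement can be sketched for the example above: each named binding must take a single constant value across everything it matches. A hypothetical illustration (invented names; not the real dim group matcher) of extracting and checking two of the bindings:

```java
import java.util.Arrays;

// Hypothetical sketch of the consistency check a dim group matcher would
// apply to bindings like $batch and $features[$i]; NOT the real matcher.
public class PatternBindings {

  // "$batch...": the shared leading prefix; throws if the shapes disagree.
  public static int[] sharedPrefix(int[][] shapes, int prefixLen) {
    int[] prefix = Arrays.copyOf(shapes[0], prefixLen);
    for (int[] s : shapes) {
      if (!Arrays.equals(Arrays.copyOf(s, prefixLen), prefix)) {
        throw new IllegalArgumentException("inconsistent binding for prefix");
      }
    }
    return prefix;
  }

  // "$features[$i]": collect the per-shape trailing dimension.
  public static int[] trailingDims(int[][] shapes) {
    int[] features = new int[shapes.length];
    for (int i = 0; i < shapes.length; i++) {
      features[i] = shapes[i][shapes[i].length - 1];
    }
    return features;
  }
}
```

Applied to the tensors above, sharedPrefix yields $batch=[100, 128] and trailingDims yields $features=[2, 4]; a shape list with mismatched prefixes would fail the match.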

crutcher commented 7 months ago

Another viable approach would be to have template families, where each template family has different replacement rules, in case we can't easily construct a single template family that is sufficiently general.