facebookresearch / torchdim

Named tensors with first-class dimensions for PyTorch
BSD 3-Clause "New" or "Revised" License

Feedback #3

Open arogozhnikov opened 2 years ago

arogozhnikov commented 2 years ago

The description says you collect feedback, but it does not specify how it should be provided, so I should open an issue ... I guess?

First, nice job @zdevito! torchdim looks very promising; in particular, indexing looks very friendly.

Unforeseen axes

Curious how you plan to implement operations that introduce a new axis, like boolean indexing or bincount / unique / set-like operations.

One possible way would be to return a new axis object along with the result, but that has issues:

x1, axis1 = bincount(x)
x2, axis2 = bincount(x)
x1 + x2 # do they have different axes or the same axis?

This can be solved by adding one more argument or by allowing manual 'coalescing' of axes.
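To make the 'coalescing' option concrete, here is a purely hypothetical sketch; neither the tuple-returning bincount nor coalesce exists in torchdim, the names are only for illustration:

x1, axis1 = bincount(x)      # each call returns a fresh axis object
x2, axis2 = bincount(x)      # a different fresh axis object
coalesce(axis1, axis2)       # hypothetical: declare the two axes to be the same axis
x1 + x2                      # now adds elementwise along the shared axis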

Concatenation / chunking of named axis

Again, curious about your thoughts here.

Multi-axes

Cases where a single function must handle tensors of several possible dimensionalities are frequent.

Potentially you can leave those problems to positional axes, but I'd recommend exploring the direction of multi-axes:

(Q[b, qaxes, [head, c]] * K[b, kaxes, [head, c]]).sum(c).order(b, *qaxes, *kaxes, head)

# * not allowed in indexing
(Q.index(b, *qaxes, [head, c]) * K.index(b, *kaxes, [head, c])).sum(c).order(b, *qaxes, *kaxes, head)

It has a very 'pythonic' look; under the hood, iterating over a multi-axis would yield a single helper object that designates its position among the other axes.
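A hedged sketch of that idea; the multi_dims factory and the placeholder object are purely hypothetical names, not torchdim API:

qaxes = multi_dims()                 # hypothetical factory for a multi-axis of unknown length
q = Q.index(b, *qaxes, [head, c])    # unpacking *qaxes yields a single placeholder object
                                     # marking where the q axes sit; how many there are
                                     # is inferred from Q's dimensionality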

Delayed computations

It is a super-clever trick to delay multiplication until a possible summation follows, but making it a single explicit operation is more predictable:

x = a * b
result = (x * c).sum(i, j) # here einsum-ification probably happens

x + 1 # the user actually expected x to be materialized by this point.

Just placing that in a function does not look worse to my eye, but I'm open to other opinions:

sum_product([i, j], a, b, c)
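For reference, a plain-PyTorch sketch of what such an explicit operation would compute; sum_product itself is only a proposed name, and the example assumes three tensors sharing dims (i, j, k) with the reduction over i and j:

import torch

a = torch.rand(2, 3, 4)   # dims (i, j, k)
b = torch.rand(2, 3, 4)
c = torch.rand(2, 3, 4)

# sum_product([i, j], a, b, c) would multiply the tensors with broadcasting and
# sum over i and j in one fused step, i.e. a single einsum:
result = torch.einsum('ijk,ijk,ijk->k', a, b, c)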

Calling functions

batch, inputs, hidden, classes = dims(4)
print(loss(w1[inputs, hidden], w2[hidden, classes], images[batch, inputs], labels[batch]))

Can you provide a more complete example here? It is unclear how the loss function can take a matmul of images and w1, because it needs to sum over the hidden dimension, but hidden was not passed to the function.

More broadly, there should be some contract for how the callee interprets its inputs (from this example it seems the callee deals only with non-named axes, and the behavior of named axes is left to the calling function, but maybe I misunderstand). More examples would be very helpful here.

Interaction with deep learning blocks

Can you explain how DL operations (e.g. convolution) would handle named dimensions (and would they)?

Add Dims context manager

with dims(6) as (h2, w2, c, b, h, w):
    ...  # computations

The suggestion may sound a bit strange, but here is the rationale: if you don't have an axis object, you can't manipulate it, so the whole tensor becomes non-manipulable.

I expect users will commonly return created tensors without order-ing them first and then deal with the downstream problems (since those tensors look like scalars to outer code, they will not error out in most operations, and users will end up chasing the skipped order call).

Exiting the context manager should deallocate all tensors that use axis objects created with it: more efficient memory management almost for free, plus in a large number of cases you can point the user to the problem immediately.
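A minimal sketch of what such a context manager could look like, assuming torchdim's existing dims() factory; the name dim_scope and the release hook for the deallocation / early-error behavior are hypothetical:

from contextlib import contextmanager

@contextmanager
def dim_scope(n):
    axes = dims(n)            # hand out n fresh first-class dims
    try:
        yield axes
    finally:
        for ax in axes:
            release(ax)       # hypothetical: free or flag tensors still bound to ax

# usage mirroring the suggestion above:
# with dim_scope(6) as (h2, w2, c, b, h, w):
#     ...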

Using Better Terminology

'Flattening and Splitting Dims': neither term suits the context. Yes, those are torch ops, but they become inappropriate as you move from discussing old-style ops to operations that are focused on axes. For instance, the phrase 'flatten the dimensions' does not make sense, as dimensions/axes are already flat.

Einops uses the terminology 'composition and decomposition of axes', because:

1. it is obvious that when you compose you get fewer axes;
2. it hints that the original content is preserved;
3. the wording works: decomposition reverses composition, even kids know that (compare that to flatten vs split dims);
4. you can refer to a 'composed axis' and 'composing axes', which is helpful when discussing code.

Let's use this better terminology.

zdevito commented 2 years ago

Thanks for the feedback! This is a good place to post it.

Unforeseen Axes

A rule of thumb I've been following is to always pass Dim objects as arguments, even when an operator is what constructs the first object with that dimension; this is the 'add one more argument' approach:

axis = dims()
x1 = bincount(x, axis)
x2 = bincount(x, axis)
x1 + x2 # both have the same axis

It avoids having to rename dimensions when they are the same (though renaming can also be done, e.g. with x1.index(axis1, axis2) in the example above). Similarly, [stacking](https://github.com/facebookresearch/torchdim/blob/de9785e58eb0b914638e942a65e3f74d2a695df9/torchdim/reference.py#L488) and split both take the new dim(s) being added as an argument. There is still some work to be done to figure out the best API here. For instance, maybe split should take either a single new Dim (all splits must have the same size), or a list of unsized Dims (split the dim evenly but handle the case where it does not divide evenly), or a list of sized Dims (a fully customized way of splitting the dimension).
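A hedged illustration of those three split variants; the call shape x.split(existing_dim, new_dims) and the way sizes get bound are assumptions, not the settled API:

i = dims()                      # the existing dim being split

# 1) a single new Dim: every chunk must have the same size
chunk = dims()
pieces = x.split(i, chunk)

# 2) a list of unsized Dims: split evenly, handling a remainder that does not divide evenly
a, b, c = dims(3)
pieces = x.split(i, [a, b, c])

# 3) a list of Dims whose sizes are already bound: a fully customized split
d, e = dims(2)                  # assume d and e carry explicit sizes
pieces = x.split(i, [d, e])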

zdevito commented 2 years ago

Multi-axes

I believe this is similar to DimLists, a feature that is in the repo but that I haven't documented yet because I haven't worked out all of the corner cases. As in the example, their length is inferred from context, such as indexing. Once initialized, they behave just like a list of dimensions.

zdevito commented 2 years ago
For example, something like z_score is natural to express with a List[Dim] argument:

def z_score(x: Tensor, avg_over: List[Dim]):
    m = x.mean(avg_over)
    s = x.std(avg_over, unbiased=False) + 1e-5
    return (x - m) / s

However, others can get a bit verbose, like an n-D convolution:

def conv(lhs: Tensor, rhs: Tensor,
         l_spatial: List[Dim], r_spatial: List[Dim], o_spatial: List[Dim],
         b: Dim, c_in: Dim, c_out: Dim,
         dilation: List[int] = None, stride: List[int] = None, expand: List[int] = None):
    ...
zdevito commented 2 years ago

Delayed computation will take some experimenting. If it turns out that sum_product is the only instance, then making it explicit helps remove cleverness with little loss of readability. There are other cases where delaying computation might be helpful as well. For instance, with indexing using dims, there ends up being a lot of cheap pointwise operators strung together to compute the index. It might make sense to delay and compile them with something like Triton (and generate their autograd). However, I recognize how important it is to make the behavior very predictable, and I wouldn't defer execution across a long chain of operators.
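As a plain-PyTorch illustration of the kind of work being described (not torchdim internals): indexing over several dims expands into a chain of cheap pointwise integer ops before a single gather, which is the part that could be fused by something like Triton:

import torch

i0 = torch.arange(4).reshape(4, 1, 1)
i1 = torch.arange(5).reshape(1, 5, 1)
i2 = torch.arange(6).reshape(1, 1, 6)
flat_index = (i0 * 5 + i1) * 6 + i2        # several small pointwise ops
x = torch.rand(4, 5, 6)
gathered = x.reshape(-1)[flat_index]       # one gather at the end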