arogozhnikov opened this issue 2 years ago
Thanks for the feedback! This is a good place to post it.
A rule of thumb I've been following is to always pass Dim objects as arguments, even when an operator is what constructs the first object with that dimension; this is the 'add one more argument' approach:
```python
axis = dims()
y1 = bincount(x1, axis)
y2 = bincount(x2, axis)
y1 + y2  # both have axis
```
It avoids having to rename dimensions when they are the same (though renaming can also be done with `y1.index(axis1, axis2)`). Similarly, [stacking](https://github.com/facebookresearch/torchdim/blob/de9785e58eb0b914638e942a65e3f74d2a695df9/torchdim/reference.py#L488) and `split` both take the new dim(s) being added as an argument. There is still some work to be done to figure out the best API here. For instance, maybe `split` should take either a single new `Dim` (all splits must have the same size), a list of unsized `Dim`s (split evenly, but handle the case where the dimension does not divide evenly), or a list of sized `Dim`s (a fully customized way of splitting the dimension).
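For concreteness, here is a sketch of what those three conventions could look like. Nothing here is a committed API: the free function `split(tensor, old_dim, new_dims)` and the `dims(sizes=...)` constructor are assumptions made only to illustrate the three variants.

```python
import torch

x = torch.randn(6, 12)
b, a = dims(2)
xa = x[b, a]                  # `a` is the dim we want to split

# 1) a single new Dim: every chunk has the same size
chunk = dims()
y = split(xa, a, chunk)

# 2) a list of unsized Dims: split evenly, but tolerate an uneven remainder
i, j = dims(2)
y = split(xa, a, [i, j])

# 3) a list of sized Dims: the caller fully controls the split
k1, k2 = dims(sizes=[4, 8])   # assuming sizes can be bound at construction
y = split(xa, a, [k1, k2])
```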
I believe this is similar to `DimList`s, a feature that is in the repo but that I haven't documented yet because I haven't worked out all of the corner cases. As in the example, their length is inferred from context, such as indexing. Once initialized, they behave just like a list of dimensions.
```python
def z_score(x: Tensor, avg_over: List[Dim]):
    # normalize x over the dims listed in avg_over, keeping its other dims
    m = x.mean(avg_over)
    s = x.std(avg_over, unbiased=False) + 1e-5
    return (x - m) / s
```
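A quick usage sketch, assuming the same implicit torchdim imports as the snippets above and a plain 2-D input:

```python
import torch

x = torch.randn(32, 16)
b, c = dims(2)                 # batch and feature dims
z = z_score(x[b, c], [b])      # normalize each feature across the batch dim
z_positional = z.order(b, c)   # back to an ordinary (32, 16) tensor
```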
However, others can get a bit verbose, like an n-D convolution:
```python
def conv(lhs: Tensor, rhs: Tensor,
         l_spatial: List[Dim], r_spatial: List[Dim], o_spatial: List[Dim],
         b: Dim, c_in: Dim, c_out: Dim,
         dilation: List[int] = None, stride: List[int] = None, expand: List[int] = None):
    ...
```
Delayed computation will take some experimenting. If it turns out that `sum_product` is the only instance, then making it explicit helps remove cleverness with little loss of readability. There are other cases where delaying computation might be helpful as well. For instance, with indexing using dims, there end up being a lot of cheap pointwise operators chained together to compute the index. It might make sense to delay them and compile them with something like Triton (and generate their autograd). However, I recognize how important it is to make the behavior very predictable, and I wouldn't defer execution across a long chain of operators.
The description says you collect feedback, but it doesn't specify how feedback should be provided, so I guess I should open an issue?
First, nice job @zdevito! torchdim looks very promising; in particular, indexing looks very friendly.
Unforeseen axes
Curious how you plan to implement operations that introduce a new axis, like boolean indexing or bincount / unique / set-like operations.
One possible way would be to return a new axis object along with the result, but it has issues: two calls that each return their own fresh axis produce results whose axes don't line up even when they logically should, so they can't be combined without renaming.
This can be solved by adding one more argument or by allowing manual 'coalescing' of axes.
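To make the problem concrete, here is a sketch of the 'return a new axis object along with the result' variant; `bincount_new_dim` is made up purely to illustrate it:

```python
import torch
from torch import Tensor

def bincount_new_dim(t: Tensor):
    """Toy stand-in for an op that mints a fresh output Dim on every call."""
    counts = torch.bincount(t)
    new_axis = dims()            # a brand-new Dim created inside the call
    return counts[new_axis], new_axis

counts1, c1 = bincount_new_dim(torch.tensor([0, 1, 1, 2]))
counts2, c2 = bincount_new_dim(torch.tensor([0, 0, 2, 2]))
# counts1 + counts2 broadcasts over c1 x c2 instead of aligning the two count
# vectors, because c1 and c2 are unrelated Dims; the axes first have to be
# renamed or coalesced to be treated as the same dimension
```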
Concatenation / chunking of a named axis
Again, curious about your thoughts here.
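For reference, one shape such an API could take, following the 'pass the new dim(s) in' style of the stack/split functions mentioned in the reply above (the exact signatures may differ):

```python
import torch

x1, x2, x3 = (torch.randn(4, 5) for _ in range(3))

# stacking: the caller supplies the Dim that will index the stacked tensors
group = dims()
z = stack([x1, x2, x3], group)   # z carries `group` with size 3

# chunking would be the inverse: hand the new piece Dim(s) to split,
# as in the split sketch earlier in the thread
```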
Multi-axes
Cases where a single function should deal with tensors of several possible dimensionalities are frequent.
Potentially you can leave those problems to positional axes, but I'd recommend exploring the direction of multi-axes:
It has a very 'pythonic' look; under the hood, iterating over a multi-axis would yield a single helper object that designates its position among the other axes.
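This looks close to the `DimList` feature mentioned in the reply above; a rough sketch of how a multi-axis might read, assuming a `dimlists()` constructor whose length is inferred from the indexing expression:

```python
import torch

image = torch.randn(8, 3, 32, 32)       # the same code would handle (8, 3, 32) too
batch, channels = dims(2)
spatial = dimlists()                    # length inferred on first use
x_fc = image[batch, channels, spatial]  # spatial binds the remaining trailing dims
pooled = x_fc.mean(spatial)             # reduce over all spatial dims at once
```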
Delayed computations
It is a super-clever trick to delay multiplication until a possible summation follows, but making it a single operation is more predictable.
Just placing that in a function does not look worse to my eye, but I'm open to other opinions.
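For example, the explicit single operation could be a small wrapper; this is a sketch where `sum_product` is the proposed helper, with the delayed `(a * b).sum(dims)` spelled out eagerly inside:

```python
import torch

def sum_product(a, b, reduce_dims):
    # explicit multiply-then-reduce over first-class dims, done eagerly
    return (a * b).sum(reduce_dims)

# usage: a matrix multiply expressed over dims i, j, k
A, B = torch.randn(4, 5), torch.randn(5, 6)
i, j, k = dims(3)
C = sum_product(A[i, j], B[j, k], j)   # contract over j; result carries i and k
```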
Calling functions
Can you provide a more complete example here? It is unclear how the `loss` function can take a matmul of `images` and `w1`, because it needs to sum over the `hidden` variable, but that was not passed to the function. More broadly, there should be some contract for how the callee interprets its inputs (from this example it seems it deals only with non-named axes, and the behavior of named axes is left to the calling function, but maybe I misunderstand). More examples would be very helpful here.
Interaction with deep learning blocks
Can you explain how DL operations (e.g. convolution) would handle named dimensions (and would they)?
Add Dims context manager
This suggestion may sound a bit strange, but here is the rationale: if you don't have an axis object, you can't manipulate it, so the whole tensor becomes non-manipulable.
I expect users would commonly return created objects without `order`-ing them first, and then have to deal with downstream problems (since those will look like scalars to outer code, they will not error out in most operations, and users will then chase the skipped `order`).
Exit from the context manager should deallocate all tensors that use axis objects created within it => more efficient memory management almost for free, plus in a large number of cases you can point the user to the problem immediately.
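A sketch of what that context manager might look like; nothing here exists in torchdim today, it only illustrates the suggestion:

```python
import torch

x = torch.randn(4, 5)

with dims(2) as (b, c):        # hypothetical: dims() used as a context manager
    y = x[b, c].softmax(c)
    out = y.order(b, c)        # must be ordered back before the block ends
# on exit, b and c are invalidated: tensors still bound to them can be freed,
# and any later use fails immediately instead of silently acting like a scalar
```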
Using Better Terminology
'Flattening and Splitting Dims'. Neither term suits this context. Yes, those are torch ops, but they become inappropriate as you move from discussing old-style ops to operations that are focused on axes. For instance, the phrase 'flatten the dimensions' does not make much sense, as dimensions/axes are already flat.
Einops uses the terminology 'composition and decomposition of axes', because 1) it is obvious that when you compose you get fewer axes, 2) it hints that the original content is preserved, 3) wording: decomposition reverses composition, which even kids know (compare that to 'flatten' vs 'split' dims), and 4) you can refer to a 'composed axis' and 'composing axes', which is helpful when discussing code. Let's use this better terminology.