ricardoV94 opened 2 years ago
Uff, this looks bad.
This makes me wish we had dimension objects even more, I think that would solve this problem? But I don't see how we could introduce that as a strict requirement without breaking backward compatibility...
If each dimension had an object associated, and we only allow broadcasting if those objects are identical, we would always know when we create nodes how broadcasting would work, even before we know if a dimension happens to have length 1.
So something like:
dim0 = at.Dimension(name="dim0")
dim1 = at.Dimension(name="dim1")
dim2 = at.Dimension(name="dim2")
x = at.dvector("x", dims=(dim0,))
y = at.dvector("y", dims=(dim1,))
x + y # dims == (dim0, dim1)
x = at.dvector("x", dims=(dim0,))
y = at.dvector("y", dims=(dim0,))
x + y # dims=(dim0,)
X = at.dmatrix("X", dims=(dim0, dim1))
Y = at.dmatrix("Y", dims=(dim0, dim2))
X + Y # dims == (dim0, dim1, dim2)
That way we'd always statically know the result shape/dims of broadcasting ops.
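(Illustration, not part of the original comment.) The dims rule implied by those examples is just an ordered union of dimension objects, where equal lengths are never enough to merge two distinct dims; a plain-Python sketch, with at.Dimension itself still hypothetical:
def broadcast_dims(*dims_tuples):
    # Output dims are the union of the input dims, in order of first appearance.
    out = []
    for dims in dims_tuples:
        for d in dims:
            if d not in out:
                out.append(d)
    return tuple(out)

broadcast_dims(("dim0",), ("dim1",))                # ('dim0', 'dim1')
broadcast_dims(("dim0", "dim1"), ("dim0", "dim2"))  # ('dim0', 'dim1', 'dim2')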
This makes me wish we had dimension objects even more, I think that would solve this problem?
It looks like the dimension objects you're talking about would provide essentially the same information/accomplish the same things as specify_shape (alongside the existing shape inference). Perhaps the main difference is that these dimension objects are implicitly intended to work at the Type level? If so, the advantages of that approach aren't apparent at the moment.
No, it is a pretty different idea than specify_shape. specify_shape just fixes the shape to known integer values; it provides no information about whether two different arrays that happen to have the same shape actually refer to the same dimension.
By dimension I mean some constant thing that has a clear identity and might, for instance, also have associated coordinates.
So if we have dimensions time and country, and x.dims == (time,); y.dims == (country,), those two axes could never broadcast together, even if by accident we had the same number of time points as countries, because their respective dimensions refer to different, incompatible things.
And if we have two arrays with the same dimension, we still wouldn't know the shape, but only that those arrays have the same, conceptually compatible dimension. A type like Array<dims=(country,)> would mean that we are providing one value for each country (whatever number of countries we have). It never makes sense to directly add this to Array<dims=('time')>, i.e. "one value for each point in time", even if n_time == n_countries. So for instance we can redefine the sum as a broadcasting sum with output Array<dims=('time', 'country')>.
This is essentially how broadcasting works in xarray, and in my opinion it is way cleaner, prevents lots of bugs and is in general a joy to work with, compared with the relatively messy numpy behavior, where for instance all hell can break loose if it just so happens that one of your dimensions has length 1. It does require a bit of acclimatization though. :-)
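(Side note, added for illustration; assumes xarray is installed.) This is roughly what xarray does today: dimensions are matched by name rather than by position or length:
import numpy as np
import xarray as xr

x = xr.DataArray(np.arange(3.0), dims=["time"])
y = xr.DataArray(np.arange(4.0), dims=["country"])

# Different dims never align by accident; the result gets both dims.
(x + y).dims   # ('time', 'country')
(x + y).shape  # (3, 4)

# Arrays sharing the "time" dim combine elementwise along it.
z = xr.DataArray(np.ones(3), dims=["time"])
(x + z).dims   # ('time',)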
Edit: But unless we can find a way to solve the compatibility problems, I'm not sure this is the right place to discuss this.
specify_shape just fixes the shape to known integer values, it provides no information if two different arrays that happen to have the same shape actually refer to the same dimension.
specify_shape can be used to assign the same non-constant Variable to arbitrary dimensions in the shapes of multiple Variables. Combine this with shape inference and it's possible to determine that distinct Variables share exactly the same shapes in those dimensions fixed by specify_shape.
import aesara
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic_opt import ShapeFeature
shape_dim_1 = at.lscalar("time")
shape_dim_2 = at.lscalar("country")
x = at.vector("x")
x = at.specify_shape(x, (shape_dim_1,))
y = at.matrix("y")
y = at.specify_shape(y, (shape_dim_1, shape_dim_2))
shape_feature = ShapeFeature()
fg = FunctionGraph(outputs=[x.shape[0], y.shape[0]], features=[shape_feature])
aesara.dprint(fg)
# Subtensor{int64} [id A] 2
# |Shape [id B] 1
# | |SpecifyShape [id C] 0
# | |x [id D]
# | |time [id E]
# |ScalarConstant{0} [id F]
# Subtensor{int64} [id G] 5
# |Shape [id H] 4
# | |SpecifyShape [id I] 3
# | |y [id J]
# | |time [id E]
# | |country [id K]
# |ScalarConstant{0} [id L]
fg_opt = optimize_graph(fg)
aesara.dprint(fg_opt)
# time [id A]
# time [id A]
shape_feature.shape_of
# {x: (Shape_i{0}.0,),
# time: (),
# SpecifyShape.0: (time,),
# Shape.0: (TensorConstant{1},),
# ScalarConstant{0}: (),
# Subtensor{int64}.0: (),
# y: (Shape_i{0}.0, Shape_i{1}.0),
# country: (),
# SpecifyShape.0: (time, country),
# Shape.0: (TensorConstant{2},),
# ScalarConstant{0}: (),
# Subtensor{int64}.0: (),
# MakeVector{dtype='int64'}.0: (TensorConstant{2},),
# InplaceDimShuffle{}.0: (),
# MakeVector{dtype='int64'}.0: (TensorConstant{1},)}
The main distinction I see here is the desire to refer to dimensions independently from their shapes, but that's not an entirely meaningful distinction in this symbolic context, because a symbolic indexed shape can serve as a direct proxy for a dimension.
Oh, that's neat, I didn't realize you could put in variables! I'll have to play with this and see where I can get with it.
The different broadcasting behavior still does seem like a different issue though, doesn't it?
Oh, that's neat, I didn't realize you could put in variables! I'll have to play with this and see where I can get with it.
There are a lot of places where we could/should be using shape inference, but we don't, so, while this stuff is possible, it's not necessarily used/as integrated as one would expect.
The different broadcasting behavior still does seem like a different issue though, doesn't it?
I'm not entirely clear on this situation yet; I only mentioned specify_shape and shape inference because, when combined, they should—theoretically—allow us to determine (non-)equivalent shapes when static/Type-level information isn't available.
It seems like we might need a new Op that unbroadcasts (reduces) arrays to a given shape or leaves the input unchanged.
Check https://mostafa-samir.github.io/auto-diff-pt2/#unbroadcasting-adjoints
Something like:
x = at.matrix("x")
# at.reduce_to is probably a better name
# maybe sum is all we will ever need
y1 = at.unbroadcast_to(x, shape=(1, 5), reduce_op="sum")
y1.eval({x: np.ones((5, 5))}) # [[5, 5, 5, 5, 5]]
# This won't do anything, but shape may be only
# known at runtime, as in the example in this issue!
y2 = at.unbroadcast_to(x, shape=(5, 5))
y2.eval({x: np.ones((5, 5))}) # np.ones((5, 5))
# If the shape is not compatible with something that could
# have been broadcasted to the input shape, an error is raised
y3 = at.unbroadcast_to(x, shape=(2, 5))
y3.eval({x: np.ones((5, 5))}) # ValueError
This was also brought up by @sayam753 and @purna135 on Slack in relation to their work on batched solve, where dynamic unbroadcasting of gradients also crops up. It was that discussion that led me to suspect this bug!
Edit: This may be possible already without a specialized Op, if sum allows for symbolic axis? Does it?
In that case we could cook up a helper pretty quickly, and perhaps add a rewrite for the case where the axes are constant-folded during compilation, since a sum with constant axes is more efficient.
Edit: Sum does not allow for variable axis
Edit: mentioned other related issues in the top comment.
Let's try to move to the Blockwise form of this problem/situation. That way, we can attempt to make some progress on https://github.com/aesara-devs/aesara/issues/695 (e.g. finish implementing Blockwise.Lop in https://github.com/aesara-devs/aesara/pull/757) and address these issues (via inheritance from the Elemwise case).
This is also likely a problem in the grad of BroadcastTo, which calls infer_broadcastable and which defaults to assuming something will not have broadcasted if a static shape of 1 can't be inferred.
Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.
Maybe it would help if we added an op that precisely computes what actually happens when we broadcast several arrays?
So a wrapper around something like this (I really hope this can be simplified a bit, though). For known shapes this could be constant-folded.
def broadcast_patterns_numpy(*shapes):
    """Given input shapes predict broadcasting behavior.

    Returns
    -------
    - output_shape: List[int]
        The shape of the broadcasted array
    - did_broadcast: List[List[bool]]
        For each input and each output dimension: True, if the input
        was broadcasted in this particular axis
    - dim_index: List[List[Optional[int]]]
        For each input and each output dimension: An index that indicates
        which axis of the input array corresponds to the output axis. If
        the input did not have an axis, None
    """
    out_rank = max(len(shape) for shape in shapes)
    shapes_rev = [shape[::-1] for shape in shapes]
    # by_rank[i] contains the lengths of the
    # input dimensions in position i, counted
    # from the back
    by_rank = [[] for _ in range(out_rank)]
    for shape_rev in shapes_rev:
        i = -1
        for i, length in enumerate(shape_rev):
            by_rank[i].append(length)
        for j in range(i + 1, out_rank):
            by_rank[j].append(None)
    does_broadcast_rev = [[] for _ in shapes]
    dim_index_rev = [[] for _ in shapes]
    output_shape_rev = []
    for rank, lengths in enumerate(by_rank):
        lengths_non_zero = [length for length in lengths if length is not None and length != 1]
        if not lengths_non_zero:
            out_length = 1
        else:
            out_length = lengths_non_zero[0]
            for length in lengths_non_zero[1:]:
                if length != out_length:
                    raise ValueError(f"Could not broadcast at rev rank {rank}: {lengths}")
        output_shape_rev.append(out_length)
        for input, length in enumerate(lengths):
            if length is None:
                does_broadcast_rev[input].append(True)
                dim_index_rev[input].append(None)
            else:
                does_broadcast_rev[input].append(length == 1 and out_length != 1)
                n_before = sum(idx is not None for idx in dim_index_rev[input])
                dim_index_rev[input].append(len(shapes[input]) - n_before - 1)
    does_broadcast = [bools[::-1] for bools in does_broadcast_rev]
    dim_index = [idxs[::-1] for idxs in dim_index_rev]
    output_shape = output_shape_rev[::-1]
    return output_shape, does_broadcast, dim_index
For example if we broadcast arrays of shape (), (2, 5, 0) and (1, 5, 1) we'd get
out_shape, did_broadcast, dim_index = broadcast_patterns_numpy((), (2, 5, 0), (1, 5, 1))
out_shape == [2, 5, 0]
did_broadcast == [
    [True, True, True],     # The first argument was broadcast in all axes
    [False, False, False],  # The second argument wasn't broadcasted at all
    [True, False, True],    # The third argument was broadcasted in the first and last axis
]
dim_index == [
    [None, None, None],  # New axes were created for all output axes for the first arg
    [0, 1, 2],           # For the second arg all axes were used and their order didn't change
    [0, 1, 2],
]
The dim_index is a little redundant in numpy broadcasting, but it would be really helpful in xarray broadcasting. ;-)
So a wrapper around something like this (I really hope this can be simplified a bit, though). For known shapes this could be constant-folded.
The function I linked to in my previous comment already computes broadcast shapes.
For example if we broadcast arrays of shape (), (2, 5, 0) and (1, 5, 1) we'd get
What's up with that shape of (2, 5, 0)?
I think that combo is invalid as per numpy broadcasting rules due to the zero?
Maybe it would help if we added an op that precisely computes what actually happens when we broadcast several arrays?
That sounds valid, but it seems like a more convoluted answer if you ONLY want to fix this problem.
In these cases we have the input and broadcasted gradient output, so the only thing we need is to reduce that gradient along the dimensions that were of size 1 in the input.
Actually, in the Elemwise case we don't even have to worry about new dims, because make_node adds Dimshuffles to the inputs to align the number of dims (but we will perhaps remove that one day)
So what we need is just:
# perform method of new Op boils down to this
def unbroadcast_to(x, shape):
    axis_to_sum = [
        i
        for i, (s, xs) in enumerate(zip(shape, x.shape))
        if s == 1 and xs != 1
    ]
    if not axis_to_sum:
        return x
    return np.sum(x, axis=tuple(axis_to_sum), keepdims=True)
In the grad where this issue would crop up we would do something like
grad = unbroadcast_to(bcast_grad, shape=input.shape)
And then we could have some rewrites to try to get rid of this Op during compilation. For instance:
1) if the target shape and the input shapes are known to be equivalent, we can remove the Op
2) if the target shape has some (constant-folded) 1s, we can replace the original input by one where we summed the known 1s dimensions already (hopefully there's already a rewrite to get rid of useless sums along dimensions of size 1; if not, we can add one)
3) if the target shape has no (constant-folded) 1s, we can remove the Op (or replace it by some Asserts related to the shape if we want to) (see the sketch below)
Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.
The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time if something will have had a shape of 1 (or certainly not 1) is to raise in the grad method.
If IfElse allowed for different shapes in the two branches we could also write a symbolic graph that applies the needed logic, but from one of the open issues it seems that both branches must have the same shape.
Allowing sum to have a symbolic axis (as long as keepdims is used, this should be fine for Aesara) would also allow for a simple solution without new Ops. But maybe that would raise a whole new set of problems.
The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time if something will have had a shape of 1 (or certainly not 1) is to raise in the grad method.
Simply put, if we don't have the information at compile time, then it needs to be handled at run-time.
The latter is the only scenario in which an Op would be helpful; however, I have yet to see why a new Op is necessary for anything in this scenario. The reason(s) why a new Op is completely necessary also need to be very clear in order to merge anything that takes such an approach.
Additionally, the description of this issue needs to clarify which result is correct and why.
Additionally, the description of this issue needs to clarify which result is correct and why.
The gradients should have the same shape of the inputs, so the case where row is used is correct.
This issue will arise for any Op that may or may not broadcast its inputs at runtime. If broadcast occurs you need to sum the gradient across the broadcasted dimensions, otherwise you should not. However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.
Note that explicitly broadcasting all the inputs (like the explicit Dimshuffles introduced by Elemwise) wouldn't fix this either. The gradient of BroadcastTo shares the same limitations as Elemwise.
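A small numpy-only illustration of why the reduction is needed (not Aesara code): when an input dimension of length 1 is broadcast, each of its elements feeds several output elements, so the gradient w.r.t. that input sums the upstream gradient over that axis; when nothing is broadcast, no sum should happen:
import numpy as np

g_out = np.arange(6.0).reshape(2, 3)    # upstream gradient of the broadcasted result

x = np.ones((1, 3))                     # broadcast along axis 0 inside the forward op
g_x = g_out.sum(axis=0, keepdims=True)  # gradient must have x's shape: (1, 3)

x2 = np.ones((2, 3))                    # no broadcasting happened
g_x2 = g_out                            # gradient keeps shape (2, 3); summing would be wrong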
However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.
It does.
However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.
It does.
Via what?
To be clear, we need something that can do the following.
def foo(x, y):
    ...

x = at.matrix("x")
y = np.random.normal(size=(5, 5))
f = aesara.function([x], foo(x, y))
assert f(np.ones((1, 5))) == np.sum(y, axis=0, keepdims=True)
assert f(np.ones((5, 1))) == np.sum(y, axis=1, keepdims=True)
assert f(np.ones((1, 1))) == np.sum(y, axis=(0, 1), keepdims=True)
assert f(np.ones((5, 5))) == y
About the shape 0: Numpy allows that actually (although I think that might be a design flaw and I think this complicates things a bit, but I haven't thought it through):
In these cases we have the input and broadcasted gradient output, so the only thing we need is to reduce that gradient along the dimensions that were of size 1 in the input.
Is this actually enough? We still need to know if keepdims=True or keepdims=False, right? But I guess we can use the rank of the arrays for that? I actually thought at first we'd also need to know if broadcasting did in fact occur, but I guess an extra sum(keepdims=True) doesn't hurt if the dimension really has length 1.
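A quick numpy check of that last point (illustrative only): a sum with keepdims=True over an axis that already has length 1 is a no-op, so the extra reduction is harmless when no broadcasting actually happened:
import numpy as np

g = np.random.normal(size=(1, 5))
assert np.array_equal(np.sum(g, axis=0, keepdims=True), g)  # the "unbroadcast" changes nothing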
About the shape 0: Numpy allows that actually (although I think that might be a design flaw and I think this complicates things a bit, but I haven't thought it through):
Did not expect that! I don't think that's differentiable anyway :P
Is this actually enough? We still need to know if keepdims=True or keepdims=False, right? But I guess we can use the rank of the arrays for that?
Yeah, extra dims are easy to know. In the Elemwise case, for now, it's always a case of keepdims=True because Elemwise adds new dims to the inputs in make_node via Dimshuffles automatically, so the loss of dims is handled by the grad of Dimshuffle. But we can try to make this more useful and also account for those cases.
I actually thought at first we'd also need to know if broadcasting did in fact occur, but I guess an extra sum(keepdims=True) doesn't hurt if the dimension really has length 1.
Yeah I assumed that wouldn't hurt, but we can be more clever about it if needed.
Did not expect that! I don't think that's differentiable anyway :P
Yes it is, the derivative is just an empty array. :-)
I wrote a couple of tests, maybe we can convert those into pytest.parametrized... (and maybe check the grad numerically)
Currently, I get test failures for those input shapes:
[(1,), (2,)] Grad 0 should be (1,) but is (2,)
[(0,), (1,)] Grad 1 should be (1,) but is (0,)
[(2, 1), (2, 2)] Grad 0 should be (2, 1) but is (2, 2)
[(2, 3, 4, 1), (1, 3, 1, 5), (5,)] Grad 0 should be (2, 3, 4, 1) but is (2, 3, 4, 5)
[(0, 3, 4, 1), (1, 3, 1, 5), (5,)] Grad 0 should be (0, 3, 4, 1) but is (0, 3, 4, 5)
[(2, 3), (), (1, 1, 3)] Elemwise{add,no_inplace}.grad returned a term with 3 dimensions, but 2 are required.
Found another error message:
(
    [(1, 1, 1), (), (2, 1)],
    [np.float32, np.float32],
    (1, 2, 1),
    np.float32,
),
[(1, 1, 1), (), (2, 1)] Cannot drop a non-broadcastable dimension: (True, False, True), []
Do these errors come from known_shapes=False?
Most do. However, these also happen with known_shapes=True:
[(2, 3), (), (1, 1, 3)] Elemwise{add,no_inplace}.grad returned a term with 3 dimensions, but 2 are required.
[(1, 1, 1), (), (2, 1)] Cannot drop a non-broadcastable dimension: (True, False, True), []
[(1, 1), (), (2,)] Cannot drop a non-broadcastable dimension: (True, False), []
The "cannot drop non-broadcastable dimension" sounds like another legacy of treating unknown dims as != 1, but I am surprised that it shows up here.
The "expected 2 but got 3" I am not sure.
@ricardoV94 You mean something like this with broadcast_to?
import numpy as np

import aesara
import aesara.tensor as at


class SumToShape(at.Op):
    __props__ = ()

    def make_node(self, tensor, target_shape):
        dtype = tensor.type.dtype
        target_shape = at.as_tensor_variable(target_shape)
        # output = type(tensor.type)(dtype=dtype, shape=target_shape)()  # Why doesn't this work?
        output = type(tensor.type)(dtype=dtype, shape=[None for _ in target_shape])()
        return aesara.tensor.basic.Apply(self, [tensor, target_shape], [output])

    def perform(self, node, inputs, outputs):
        (tensor, target_shape) = inputs
        (output,) = outputs
        leading_dims = len(tensor.shape) - len(target_shape)
        assert leading_dims >= 0
        sum_axis = [True] * leading_dims
        for actual_length, target_length in zip(tensor.shape[leading_dims:], target_shape):
            not_equal = actual_length != target_length
            if not_equal:
                assert target_length == 1
            sum_axis.append(not_equal)
        sum_axis = np.nonzero(sum_axis)[0]
        output_value = np.sum(tensor, axis=tuple(sum_axis), keepdims=True)
        output_value = output_value[(0,) * leading_dims]
        output[0] = output_value


_SumToShape = SumToShape()


def sum_to_shape(value, shape):
    return _SumToShape(value, shape)


x = at.tensor(shape=(None, None, None), dtype=np.float32)
sum_to_shape(x, (3, 1)).eval({x: np.ones((4, 3, 4), dtype=np.float32)})
Small side note: any idea why I can't set the correct shape in make_node?
Yeah, something like that.
Small sidenode: any idea why I can't set the correct shape in make_node?
Type shapes must be non-symbolic tuples, but you converted it to a TensorVariable in the line above
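For reference, a small sketch of that distinction, assuming a recent Aesara where TensorType accepts a static shape argument (as in the snippet above): the static shape stored on a tensor Type must be a plain tuple of ints/None, while a symbolic target shape can only be an input to the Op:
import aesara.tensor as at
from aesara.tensor.type import TensorType

# OK: static shape information, None marks an unknown dimension
TensorType("float64", shape=(None, 1))

# Not OK: after at.as_tensor_variable(...) the shape is a symbolic Variable,
# so make_node has to fall back to shape=(None, ...) for the output type
target_shape = at.as_tensor_variable((3, 1))
# TensorType("float64", shape=target_shape)  # this is the variant that fails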
If IfElse allowed for different shapes in the two branches we could also write a symbolic graph that applies the needed logic, but from one of the open issues it seems that both branches must have the same shape.
If you're referring to my comment here, see my follow-up/amendment, which implicitly states that IfElse is sufficient for the branching necessary in this issue.
That said, let's explore an IfElse-related approach before introducing new, specialized Ops.
However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.
It does.
Via what?
I was referring to the branching support required to solve this issue. Again, IfElse supports this, and I think that some clever use of other, existing Ops might accomplish the same thing. Finally, a custom Op that encapsulates the requisite branching logic is always possible, as you folks have been discussing.
In other words, Aesara definitely provides the building blocks for a viable solution to this problem.
So far we've been thinking if it would be best to add an op, or to try and do the required operations with existing ops (eg IfElse). But maybe there is a third option, that might have nice side-effects:
We could also extend the functionality of the Elemwise op a bit, to something like ElemwiseSum. I've been thinking about this anyway, because we get a lot of sum(elemwise) in logp graphs, and it is usually a waste to first store the result, only to add the results up later. Doing both at the same time often speeds up reductions quite a bit.
But wouldn't this allow us to nicely represent what we need for the gradient ops as well, if Elemwise had an additional parameter sum_outputs_to_shape or so, that would add similar logic as the SumToShape above?
In a way this pretty much pushes the problem down to the backends, and maybe that is not what we want, but on the other hand, maybe the backend is actually the place where we can handle this most easily and get good performance?
For numba I've played with ElemwiseSum implementations that work something like this:
@numba.njit(fastmath=True)  # fastmath helps a lot here, so that llvm can optimize the reduction
def elemwise_sum(x, y):
    # Handle input variable broadcasting, unfortunately this isn't happening in
    # the numba impl of `np.nditer` currently, which prevents some simd optimizations
    # unfortunately....
    # (`output_type` and `inner_func` stand in for the output dtype and the scalar op.)
    x, y = np.broadcast_arrays(x, y)
    total = output_type(0.)
    for x_val, y_val in np.nditer((x, y)):
        total += inner_func(x_val, y_val)
    return total
If we can find a way to generalize this to partial reductions, then we could at the same time and with the same code improve performance of many sum(elemwise) and produce code for the gradient of elemwise.
We could also extend the functionality of the Elemwise op a bit, to something like ElemwiseSum. I've been thinking about this anyway, because we get a lot of sum(elemwise) in logp graphs, and it is usually a waste to first store the result, only to add the results up later. Doing both at the same time often speeds up reductions quite a bit.
See aesara.tensor.elemwise.CAReduce.
If we can find a way to generalize this to partial reductions, then we could at the same time and with the same code improve performance of many sum(elemwise) and produce code for the gradient of elemwise.
See https://github.com/aesara-devs/aesara/pull/599 and/or the current Numba implementations for CAReduce Ops.
Here is an implementation with IfElse; I am quite worried about performance. Especially without the ability to specify a Type which is not broadcastable (shape != 1), the IfElse will have to remain in most use cases:
import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse
import numpy as np
def unbroadcast_to(value, shape):
    for i, (s, vs) in enumerate(zip(shape, value.shape)):
        # To increase the odds of optimizing it away, it's perhaps better
        # to use the condition `at.neq(s, vs)`,
        # or allow for more useless sums
        # and only check `at.eq(s, 1)`
        value = ifelse(
            at.and_(at.eq(s, 1), at.neq(vs, 1)),
            at.sum(value, axis=i, keepdims=True),
            value,
        )
    return value
x = at.matrix("x")
y = at.tensor("float64", shape=(5, 5), name="y")
f = aesara.function([x, y], unbroadcast_to(y, x.shape))
y_val = np.random.normal(size=(5, 5))
assert np.allclose(f(np.ones((1, 5)), y_val), np.sum(y_val, axis=0, keepdims=True))
assert np.allclose(f(np.ones((5, 1)), y_val), np.sum(y_val, axis=1, keepdims=True))
assert np.allclose(f(np.ones((1, 1)), y_val), np.sum(y_val, axis=(0, 1), keepdims=True))
assert np.allclose(f(np.ones((5, 5)), y_val), y_val)
print(aesara.dprint(f))
if{inplace} [id A] 10
|Elemwise{Composite{AND(EQ(i0, i1), NEQ(i2, i1))}} [id B] 9
| |Shape_i{1} [id C] 8
| | |x [id D]
| |TensorConstant{1} [id E]
| |if{shape,inplace}.1 [id F] 7
| |Elemwise{eq,no_inplace} [id G] 3
| | |Shape_i{0} [id H] 2
| | | |x [id D]
| | |TensorConstant{1} [id E]
| |TensorConstant{1} [id I]
| |TensorConstant{5} [id J]
| |TensorConstant{5} [id J]
| |TensorConstant{5} [id J]
|InplaceDimShuffle{0,x} [id K] 6
| |Sum{axis=[1], acc_dtype=float64} [id L] 5
| |if{inplace} [id M] 4
| |Elemwise{eq,no_inplace} [id G] 3
| |InplaceDimShuffle{x,0} [id N] 1
| | |Sum{axis=[0], acc_dtype=float64} [id O] 0
| | |y [id P]
| |y [id P]
|if{inplace} [id M] 4
I am quite worried about performance
Why exactly are you worried about this? As far as I can tell, it's part of a run-time cost that we must necessarily pay in order to get the correct results in all cases.
Especially without the ability to specify a Type which is not broadcastable (shape != 1), the IfElse will have to remain in most use cases:
Don't forget: something/someone needs to specify those shape != 1 constraints, so such optimizations are fundamentally limited by user-level specificity in many cases.
We can—and will—focus on rewriting these graphs so that they're more efficient, though.
Also, when implementing a function like unbroadcast_to, we will most likely need to use aesara.scalar Ops explicitly. That will help us avoid unnecessary circular Elemwise dependencies, among other things. (This might only be relevant to related issues within Elemwise, if any, and other Op.infer_shape implementations.)
N.B. the name unbroadcast_to is a bit confusing, especially given the similarly named TensorType-affecting Ops and functions. Something like sum_broadcasted_axes is more descriptive/accurate.
Don't forget: something/someone needs to specify those shape != 1 constraints, so such optimizations are fundamentally limited by user-level specificity in many cases.
I am starting to realize that's what Theano always required of users, with the whole broadcast-flags business. By default, graphs did not allow for broadcasting, and if users wanted that flexibility they had to provide information about the dims that were to be broadcastable.
However, instead of raising ValueErrors at the Type/input levels, errors were only raised in the Ops where the distinction would actually matter, like Elemwise, so it was not clear (to me at least) that this was actually the design principle. I disagreed with @aseyboldt recently on this.
We went from not allowing broadcasting by default, to always assuming it's possible. We also switched from the Theano distinction of broadcastable (shape=1 vs shape!=1) at the type level to our new distinction where we only allow users to provide one specific type shape or say nothing at all about the type shapes (shape=x vs shape=None).
This gradient issue is where I first see why the Theano design decision might have made sense.
My concerns are two-fold:
- I suspect many optimizations will be broken by the need for the IfElse (or a custom Op) in most gradient graphs, and graphs will be considerably slower
- We won't be able to compile (or at least JIT) gradient graphs to more restricted backends like JAX, unless users specify exact shapes for all the inputs.
I don't have any benchmarks yet so it's all speculation at this point.
@ricardoV94 Yes, a change from "broadcasting is only allowed if we know for sure it's happening" to "we always first assume you can broadcast and figure it out at runtime" seems like something we should think through properly...
About the ElemwiseSum op, I think an implementation like this would do what we want at least:
a = np.random.randn(300, 400)
b = np.random.randn(400)
out = np.zeros(400)
@numba.njit
def elemwise_func(a, b):
    return np.exp(a + b)

@numba.njit(fastmath=True)
def elemwise_sum(a, b, out):
    out[...] = 0
    a, b, out = np.broadcast_arrays(a, b, out)
    for x, y, z in np.nditer((a, b, out)):
        z[()] += elemwise_func(x, y)

# Pure python impl using np.nditer
def elemwise_sum_python(a, b, out):
    out[...] = 0
    it = np.nditer(
        (a, b, out),
        flags=["external_loop", "reduce_ok", "buffered"],
        op_flags=[["readonly"], ["readonly"], ["readwrite"]]
    )
    with it:
        for x, y, z in it:
            z[...] += np.exp(x + y)
Numba doesn't seem to generate great code however, but maybe this is actually due to https://github.com/numba/numba/issues/8314.
I'm also not entirely sure if this is supposed to compile; usually np.broadcast_arrays returns readonly output, but somehow we still end up writing to it in np.nditer (and seem to get the right result). I'm also not sure if numba will specialize the case where the last iteration axis is contiguous, so that it can use SIMD and a vectorized exp...
I am starting to realize that's what Theano always requested users, with the whole broadcast flags business. By default, graphs did not allow for broadcasting, and if users wanted that flexibility they had to provide information about the dims that were to be broadcastable.
No, that's not what Theano requested, it's what Elemwise demanded; otherwise, Elemwise's interpretation of TensorType.broadcastable would've been used everywhere outside of Elemwise and its rewrites.
However, instead of raising ValueErrors at the Type/input levels, they only raised in the Ops where the distinction would actually matter like Elemwise, so it was not clear (for me at least) that this was actually the design principle. I disagreed with @aseyboldt recently on this.
In other words, that choice moved something from compile-time to run-time. That's not the direction we generally want to go, nor is it one with particularly/any notable advantages.
Also, that wasn't a general Theano design principle, because it doesn't match the Type contract. It was an Elemwise-specific limitation/constraint that sometimes leaked out of scope, and it looks like people just built around it and assumed it when convenient. It was demonstrably not a universally assumed interpretation of TensorType.broadcastable values.
We went from not allowing broadcasting by default, to always assuming it's possible. We also switched from the Theano distinction of broadcastable (shape=1 vs shape!=1) at the type level to our new distinction where we only allow users to provide one specific type shape or say nothing at all about the type shapes (shape=x vs shape=None).
No, we didn't start "always assuming it's possible"; we made it possible, and we're still in the process of doing that.
This gradient issue is where I first see why the Theano design decision might have made sense.
What you're seeing is that things can be much simpler for some Elemwise optimizations under a strict interpretation of TensorType.broadcastable. That's it.
Unfortunately, simplicity for Elemwise does not translate to simplicity in general, and that's why this old assumption ultimately ended up being a big design mistake that led to many subtle bugs and unnecessary restrictions over time.
A lot of this is already covered in https://github.com/aesara-devs/aesara/issues/335 and other, related issues/PRs, but I'll recap the basics in the following, because the amount of confusion surrounding this and the apparently strong appeal of naive reversions is alarming.
Let's start with some very simple and direct proof that Elemwise's strict interpretation is not shared by the rest of Theano:
import numpy as np
import theano
import theano.tensor as tt
theano.__version__
# '1.0.5'
x = tt.row("x")
X = tt.matrix("X")
x.broadcastable
# (True, False)
X.broadcastable
# (False, False)
# When a user says a dimension is broadcastable, it must be,
# so the interpretation of `True` is strict.
x.type.filter_variable(np.ones((2, 3)))
# TypeError: ('Non-unit value on shape on a broadcastable dimension.', (2, 3), (True, False))
x.type.is_valid_value(np.ones((2, 3)))
# False
# Theano obviously doesn't interpret `False` in a strict sense,
# because, if it did, these would necessarily fail:
X.type.filter_variable(np.ones((1, 1)))
# TensorConstant{(1, 1) of 1.0}
X.type.is_valid_value(np.ones((1, 1)))
# True
Let's also consider the documentation for a second. This link clearly states that False TensorType.broadcastable values are not strict; however, this link says "Theano needs to know" the broadcastable dimensions. What's the difference here?
The latter link is (implicitly) talking about Elemwise, which is evident from the examples it uses. Since Elemwise implements broadcasting for all the scalar Ops in Theano, it's easy to mistakenly conclude that all of Theano assumes Elemwise's TensorType.broadcastable interpretation.
Here are some illustrations of Elemwise's strict and inconsistent interpretation of TensorType.broadcastable values:
Y = tt.matrix("Y")
XY_fn = theano.function([X, Y], X * Y)
theano.printing.debugprint(XY_fn, print_type=True)
# Elemwise{mul,no_inplace} [id A] <TensorType(float64, matrix)> '' 0
# |X [id B] <TensorType(float64, matrix)>
# |Y [id C] <TensorType(float64, matrix)>
XY_fn(np.ones((2, 2)), np.ones((1, 2)))
# ValueError: Input dimension mis-match. (input[0].shape[0] = 2, input[1].shape[0] = 1)
Everything checks out at the Type-level with those arguments, but the call fails in Op.perform (i.e. run-time), and easily avoidable run-time errors aren't good for anything.
Let's try some other arguments that conflict with the False-means-not-one interpretation:
XY_fn(np.ones((1, 2)), np.ones((1, 2)))
# array([[1., 1.]])
Notice how Elemwise is selectively interpreting a False Type.broadcastable value as an indication that the corresponding shape value is not equal to one, and the check/assumption is only ever effectively applied to dimensions that are actively being broadcasted.
If you're thinking "So what?", then ask yourself: do you want ants? Because this is how you get ants. Inconsistently enforced interpretations are only good for confusion and difficult to debug code, and, with this choice, Theano accomplished both.
It might not seem all that bad when one considers Elemwise in isolation, but it becomes very serious when Elemwises are used in non-trivial graphs.
For example, recall this old Theano bug:
a = tt.scalar()
b = tt.isclose(tt.as_tensor([0, 0.5, 1, -1]), a)
b_fn = theano.function([a], b)
b_fn(0.0)
# ValueError: Input dimension mis-match. (input[0].shape[0] = 4, input[3].shape[0] = 1)
a.type.filter_variable(0.0)
# TensorConstant{0.0}
What's wrong in that example? The graph b is not ill-defined, no explicit broadcasting was used, and the inputs are all valid.
Basically, tt.isclose produced a graph that uses Ops and/or rewrites that either don't subscribe to Elemwise's strict interpretation of TensorType.broadcastable, weren't able to perfectly infer the broadcastable properties of their outputs, or any one of Theano's other kludgy workarounds (e.g. use of patternbroadcast and the like) didn't apply.
Here's that graph, in case you're curious:
theano.printing.debugprint(b_fn, print_type=True)
# Elemwise{Composite{AND(LE(Abs((i0 - i1)), i2), Invert(OR(IsNan(i3), IsInf(i3))))}} [id A] <TensorType(bool, vector)> '' 3
# |TensorConstant{[ 0. 0.5.. 1. -1. ]} [id B] <TensorType(float64, vector)>
# |InplaceDimShuffle{x} [id C] <TensorType(float64, (True,))> '' 0
# | |<TensorType(float64, scalar)> [id D] <TensorType(float64, scalar)>
# |Elemwise{Composite{(i0 + (i1 * Abs(i2)))}} [id E] <TensorType(float64, (True,))> '' 2
# | |TensorConstant{(1,) of 1e-08} [id F] <TensorType(float64, (True,))>
# | |TensorConstant{(1,) of 1e-05} [id G] <TensorType(float64, (True,))>
# | |InplaceDimShuffle{x} [id C] <TensorType(float64, (True,))> '' 0
# |Rebroadcast{0} [id H] <TensorType(float64, vector)> '' 1
# |InplaceDimShuffle{x} [id C] <TensorType(float64, (True,))> '' 0
Why not make all Ops use the strict interpretation? The reason why it probably wasn't done–and why it's still not good to do–is that we effectively cannot do it without disallowing the use of many Ops, rewrites, and ways of constructing graphs, because it requires perfect shape/broadcastable inference. A single mistaken TensorType.broadcastable introduced by one Op.make_node can propagate arbitrarily through a graph and end up breaking when the corresponding Variable is used in an Elemwise.
Remember Elemwise's selective enforcement of its TensorType.broadcastable interpretation? Well, that is part of what allowed Elemwise's overly restrictive interpretation to appear less restrictive than it actually was. Making that interpretation consistent would bring all these new and inherent restrictions to the fore.
By the way, that was not the only such occurrence of this design flaw. These kinds of issues were numerous and, sadly, a regular part of using Theano. It was one source of the ubiquitous "shape issues" that plagued users of Theano and made it a nightmare to debug in non-trivial situations.
The changes made in https://github.com/aesara-devs/aesara/issues/335 and https://github.com/aesara-devs/aesara/pull/928 correct Elemwise's inconsistent and overly strict interpretation of TensorType.broadcastable–essentially broadening its applicability and removing its restrictions. A few other parts of the codebase share Elemwise's restrictions, but–luckily–not the majority of it. From what we've seen, it all emanates from Elemwise, so the issues resulting from those changes are almost exclusively in code that deals explicitly with it (e.g. Elemwise-based rewrites). If that weren't the case, virtually nothing would work in Aesara, but the vast majority of it does.
Unfortunately, the changes introduced by the aforementioned PRs were not the end of this Elemwise refactoring. There are still some things to clean up. (Fortunately, doing so is also helping us identify outstanding gaps in our tests.)
My concerns are two-fold:
- I suspect many optimizations will be broken by the need for the IfElse (or a custom Op) in most gradient graphs, and graphs will be considerably slower
As long as people don't manually disable the "lazy" evaluation feature of the standard VMs, only one branch of these IfElse nodes will be evaluated. Furthermore, both branch sub-graphs can–and should–be optimized, just as either one would without the IfElse.
Also, there are some simple rewrites we could apply. For instance, with the IfElse fix, some_op(sum(a + b, axis=...)) and some_op(a + b) graphs would become some_op(ifelse(a_b_shape_constraints, a + b, sum(a + b, axis=...))). These IfElse nodes could be pushed down the graph, recovering the original graphs produced by Elemwise.grad (e.g. resulting in ifelse(a_b_shape_constraints, some_op(a + b), some_op(sum(a + b, axis=...)))). Now, all the original rewrites can be applied.
If you're really that concerned, use Assert (or a subclass) to represent constraints like X.shape[i] != 1 with equivalent in-graph Ops like at.neq(X.shape[i], 1). Next, add a Feature that picks up these nodes and adds an entry mapping X.shape[i] to (at.neq, 1) in a dict. There's your constraint store. Finally, create a local rewrite that tracks all other boolean Ops and replaces matching conditions in the constraint store with conforming True values. Now, those entries can be constant folded–or explicitly removed by the rewrite itself in some cases.
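A very rough, non-Aesara sketch of that constraint-store idea (every name below is made up for illustration; a real version would live in a graph Feature and use the rewrite machinery):
# maps a symbolic shape entry (e.g. X.shape[i]) to a fact asserted about it
constraint_store = {}

def record_constraint(shape_entry, op_name, value):
    # e.g. record_constraint(x_shape_0, "neq", 1) when an Assert encodes at.neq(X.shape[0], 1)
    constraint_store[shape_entry] = (op_name, value)

def simplify_condition(shape_entry, op_name, value):
    # return True when the store already implies the condition, else None to keep it symbolic
    if constraint_store.get(shape_entry) == (op_name, value):
        return True
    return None

# A local rewrite would replace implied conditions with constant True,
# after which the surrounding IfElse can be constant-folded away.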
There are numerous variations of this basic idea, but the point is that the specific IfElse nodes that generalize Elemwise's gradient graphs are not inherently difficult to remove. Also, doing so could be the start of a much more important set of features/capabilities.
- We won't be able to compile (or at least JIT) gradient graphs to more restricted backends like JAX, unless users specify exact shapes for all the inputs.
What makes you think that? This assertion really needs a concrete example (e.g. which Ops can we currently not transpile due to our non-strict interpretation of TensorType.broadcastable values), because it sounds like it's potentially referring to some very unrelated things. For instance, the JAX limitations we currently have are simply due to JAX restrictions. We could work around some of those, but definitely not all or even most, if we had better shape inference (i.e. if we could infer more static/Type-level shapes).
I don't have any benchmarks yet so it's all speculation at this point.
We could choose between strict Op restrictions that guarantee the availability of some unknown optimization benefits in certain cases, or no Op restrictions and the potential for a (temporary) loss of the aforementioned optimizations. (I'm obviously not considering the "do nothing and retain all the old bugs" option.)
We chose the latter. The former would've forced us to fix every other Op.make_node and Op.infer_shape except Elemwise (including non-Aesara Ops), and the latter forces us to fix Elemwise and its rewrites–as well as any other Ops/rewrites that make the same assumptions as Elemwise, which doesn't appear to be many, given that the vast majority of things still function correctly after these changes.
Again, at the design level, our choice does not prevent us from attaining any of the same performance benefits as the former. We just need to keep moving in the direction we're headed and establish a means of specifying and using more granular static shape constraints, then we can have all the same optimizations as before and more.
A lot of this is already covered in https://github.com/aesara-devs/aesara/issues/335 and other, related issues/PRs, but I'll recap the basics in the following, because the amount of confusion surrounding this and the apparently strong appeal of naive reversions is alarming.
I hope the "appeal to naive reversions" wasn't directed at me.
Let's start with some very simple and direct proof that Elemwise's strict interpretation is not shared by the rest of Theano: [...]
I think what's happening here is that broadcastable is not an isolated property of a Type, but something about how Types behave in Ops that allow for input broadcasting. When you add two matrices of shape (1, m), there is no broadcasting happening, so it's valid to use inputs that are not allowed to broadcast.
In other words, something must have a shape of 1 to be broadcastable but can still have a shape of 1 and not be required to broadcast. The flags were not really about the Types but what you could do with the Types. In practice you could disallow Numpy broadcasting behavior in Theano (or more precisely, Numpy broadcasting behavior was an opt-in deal).
I don't think this is useful for users, but was clearly used by the devs to produce valid gradient graphs.
I don't think we should offer the option to disable Numpy broadcasting at run-time, but we should offer the users the tools to tell Aesara when broadcasting won't be needed so that it can produce simpler graphs.
We can even raise if users said something wouldn't be broadcastable but has a shape of 1. That will make Aesara much more intuitive.
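A hedged sketch of what such a check could look like (the function name and call site are hypothetical):
def check_nonbroadcastable_dims(value_shape, nonbroadcastable):
    """Raise if a dimension the user declared non-broadcastable has length 1."""
    for axis, (length, declared) in enumerate(zip(value_shape, nonbroadcastable)):
        if declared and length == 1:
            raise ValueError(
                f"Dimension {axis} was declared non-broadcastable, "
                f"but the provided value has length 1 there."
            )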
Here are some illustrations of Elemwise's strict and inconsistent interpretation of TensorType.broadcastable values: [...]
Everything checks out at the Type-level with those arguments, but the call fails in Op.perform (i.e. run-time), and easily avoidable run-time errors aren't good for anything. Let's try some other arguments that conflict with the False-means-not-one interpretation: [...]
Notice how Elemwise is selectively interpreting a False Type.broadcastable value as an indication that the corresponding shape value is not equal to one, and the check/assumption is only ever effectively applied to dimensions that are actively being broadcasted.
Again, this makes sense if you read it not as saying something about the shape of the inputs, but about whether broadcasting is actually needed to combine them. In this case it wasn't, so it's valid to use non-broadcastable (according to Theano rules) inputs.
Unfortunately, the changes introduced by the aforementioned PRs were not the end of this Elemwise refactoring. There are still some things to clean up. (Fortunately, doing so is also helping us identify outstanding gaps in our tests.)
We chose the latter. The former would've forced us to fix every other Op.make_node and Op.infer_shape except Elemwise (including non-Aesara Ops), and the latter forces us to fix Elemwise and its rewrites–as well as any other Ops/rewrites that make the same assumptions as Elemwise, which doesn't appear to be many, given that the vast majority of things still function correctly after these changes.
Again, at the design level, our choice does not prevent us from attaining any of the same performance benefits as the former. We just need to keep moving in the direction we're headed and establish a means of specifying and using more granular static shape constraints, then we can have all the same optimizations as before and more.
I think we might need to allow for more granular static types (in particular, shape>1), which will nevertheless involve quite some work in make_node and rewrites.
I still want to see how the ifelse fares. I would be very happy if this all turned out to be a non-issue.
I'll try and open a PR to add tests for the bugs in the grad of Elemwise, Gemm, and BroadcastTo (those are all the Ops I know of that allow for runtime broadcasting).
I hope the "appeal to naive reversions" wasn't directed at me.
It was directed at all of us.
I don't think this is useful for users, but was clearly used by the devs to produce valid gradient graphs.
Although we're basically guessing why things were done in Theano, I have a strong feeling that, given the status of IfElse near the end of Theano's development, the reason a kludge like Elemwise's use of TensorType.broadcastable was employed probably has more to do with Theano's lack of support for conditionals.
Put another way, they might have wanted to go down a path like the one we're going down now, but IfElse wasn't available, finished, or guaranteed to be evaluated efficiently (i.e. lazily).
I gave the ElemwiseSum approach a bit more thought, and the more I think about it the more I like it :-)
Just to make sure we are all on the same page, here is a summary of how that would go:
- We give the Elemwise op some extra functionality (or we add a separate op ElemwiseSum) that takes as an extra argument the shape the output should be reduced to. (Alternatively we could make an inplace version, where the argument is directly the output buffer.) So Elemwise{scalar_op=exp(a + b)}(a=some_a, b=some_b, out_shape=()) would return np.exp(a + b).sum(), or Elemwise{scalar_op=exp(a + b)}(a=some_a, b=some_b, out_shape=(100,)) would return np.exp(a + b).sum(all_axis_but_the_last).
- We use this in the grad of Elemwise. We want to reduce the output of the gradient Elemwise to the shape of the corresponding input, so grad_a = Elemwise{scalar_op=scalar_grad_op_a}(a, b, a.shape).
- We look for Sum(Elemwise) ops and add a rewrite that optimizes that to Elemwise(out_shape=()).
So the remaining question would be how to implement the Elemwise with an included reduction. It turns out we can reuse the numpy broadcasting rules to do exactly that. The following implementation will already do what we want:
@numba.njit
def elemwise_sum(a, b, out_shape):
    # `out_dtype` and `elemwise_func` are placeholders, as in the snippets above
    out = np.zeros(out_shape, dtype=out_dtype)
    a_bc, b_bc, out_bc = np.broadcast_arrays(a, b, out)
    for x, y, z in np.nditer((a_bc, b_bc, out_bc)):
        z[()] += elemwise_func(x, y)
    return out
Why does this work? If out is broadcasted along a dimension, the broadcasted version of out will have its stride set to zero in that axis, so in the iteration we add to the same element of out multiple times.
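A quick numpy illustration of the stride-0 point (added here for clarity, not in the original comment):
import numpy as np

out = np.zeros((1,))
a = np.arange(4.0)

a_bc, out_bc = np.broadcast_arrays(a, out)
out_bc.shape    # (4,)
out_bc.strides  # (0,): every element of the broadcast view aliases out[0]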
However, the compiler doesn't know at compile time what the strides of the array will be, so it produces code that works for all strides but isn't particularly fast for any of them. In most applications we only really care about the innermost loop, however, so we can specialize that inner loop to make sure we can vectorize it:
import numba
import numpy as np

# Get information about simd vectorization from llvm
#import llvmlite.binding as llvm
#llvm.set_option('', '--debug-only=loop-vectorize,iv-descriptors')

@numba.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba.core import types, cgutils
    sig = types.voidptr(src)
    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

@numba.extending.intrinsic
def make_writable(typingctx, src):
    """Make a readonly array writable."""
    from numba.core import types, cgutils
    sig = src.copy(readonly=False)(src)
    def codegen(cgctx, builder, sig, args):
        (arg,) = args
        return numba.core.imputils.impl_ret_borrowed(cgctx, builder, sig.return_type, arg)
    return sig, codegen

# fastmath=True (is "afn" enough?) will use faster but less
# precise special functions if SVML is installed
# (conda install -c numba icc_rt)
#@numba.njit(fastmath={"afn"})
@numba.njit(fastmath=True)
def elemwise_func(a, b):
    return np.exp(a + b)

# Slow reference implementation
@numba.njit
def elemwise_sum_slow(a, b, out):
    out[...] = 0
    a, b, out = np.broadcast_arrays(a, b, out)
    for x, y, z in np.nditer((a, b, out)):
        z[()] += elemwise_func(x, y)

# Fast version that tries to use SIMD in some cases
@numba.njit(fastmath={"reassoc", "contract"})
def elemwise_sum(a, b, out):
    # TODO There are cases where we don't need this.
    # Can we somehow put it into the loop?
    out[...] = 0
    a, b, out = np.broadcast_arrays(a, b, out)
    # Numpy returns readonly arrays by default
    # so that users don't get surprised by the
    # arrays with stride 0. It is safe to write
    # however; in this case we *want* the behaviour
    # of strides == 0.
    a = make_writable(a)
    b = make_writable(b)
    out = make_writable(out)
    # Last loop is reduction and a and b are contiguous
    if (
        out.strides[-1] == 0
        and a.strides[-1] == 8  # TODO use dtype.itemsize
        and b.strides[-1] == 8
    ):
        for idx in np.ndindex(out.shape[:-1]):
            tmp = 0
            x = a[idx]
            y = b[idx]
            k = 0
            N = a.shape[-1]
            while k < N:
                tmp += elemwise_func(x[k], y[k])
                k += 1
            if N > 0:
                out[idx][0] += tmp
    # Last loop is reduction, but a or b are not contiguous
    elif out.strides[-1] == 0:
        for idx in np.ndindex(out.shape[:-1]):
            tmp = 0.
            x = a[idx]
            y = b[idx]
            k = 0
            N = a.shape[-1]
            while k < N:
                tmp += elemwise_func(x[k], y[k])
                k += 1
            if N > 0:
                out[idx][0] += tmp
    # All last loops are contiguous
    elif (
        a.strides[-1] == 8
        and b.strides[-1] == 8
        and out.strides[-1] == 8
    ):
        for idx in np.ndindex(out.shape[:-1]):
            x = a[idx]
            y = b[idx]
            z = out[idx]
            N = a.shape[-1]
            # Tell numba that the array is c contiguous
            z = numba.carray(address_as_void_pointer(z.ctypes.data), N, dtype=out.dtype)
            x = numba.carray(address_as_void_pointer(x.ctypes.data), N, dtype=x.dtype)
            y = numba.carray(address_as_void_pointer(y.ctypes.data), N, dtype=y.dtype)
            k = 0
            while k < N:
                z[k] = elemwise_func(x[k], y[k])
                k += 1
    else:
        # This last one doesn't seem to get vectorized at all,
        # but I'm not sure why exactly...
        for idx in np.ndindex(out.shape[:-1]):
            x = a[idx]
            y = b[idx]
            z = out[idx]
            N = a.shape[-1]
            assert z.strides[-1] > 0
            assert x.strides[-1] >= 0
            assert y.strides[-1] >= 0
            k = 0
            while k < N:
                z[k] = elemwise_func(x[k], y[k])
                k += 1

# Pure python impl using np.nditer
def elemwise_sum_python(a, b, out):
    out[...] = 0
    it = np.nditer(
        (a, b, out),
        flags=["external_loop", "reduce_ok", "buffered"],
        op_flags=[["readonly"], ["readonly"], ["readwrite"]]
    )
    with it:
        for x, y, z in it:
            #z[...] += np.exp(x + y)
            z[...] += elemwise_func(x, y)
This still isn't perfect, but it usually beats other implementations, often by a factor of 5. It could still be improved for cases where the inner loop is very small but the next inner loop is contiguous (or a reduction) as well, or if the input arrays have fortran order.
Some timings on my machine with icc_rt from numba installed:
I haven't given the question of how that would work in jax any thought yet...
@aseyboldt, in the development of Aesara we need to assign greater value to the use/reuse/extension of existing and more general (sub)systems and features. The custom Op approach you're advocating, and its very specialized Numba transpilations, can serve as a sort of target for the end results of those systems and features, but we can't develop past/around them; otherwise, they'll end up becoming less and less useful over time.
As I stated above, we need the same level of shape inference and ability to assign constraints that would allow us to produce efficient graphs using existing Ops in many other situations.
In other words, how can we produce similarly optimized Numba code for the graphs we could generate via rewriting that use existing Ops?
Sorry, but what is so specialized about an op that does elemwise reduction? It is a basic thing we want all over the place, and it can't really be built from smaller parts either.
It is a basic thing we want all over the place, and it can't really be built from smaller parts either.
We need to establish that definitively first.
Logp graphs are one big reduction, because we want a single logp value. And storing all intermediate values is known to be slow (memory pressure and all that). Just have a quick look at the timings above if you don't believe me or do some experiments yourself.
If you have ideas about how to split that apart into separate ops, I'm really curious how you'd even approach that so that the compiler can still use simd.
Now come on, logp graphs are one big reduction, because we want a single logp value. And storing all intermediate values is known to be slow (memory pressure and all that). Just have a quick look at the timings above if you don't believe me or do some experiments yourself.
I think we're talking about different things, because I'm referring to the graphs produced by Elemwise.grad, which do not determine the memory usage patterns. Those graphs undergo a few more rewrite passes that ultimately determine that.
My concerns are that we, for example, prematurely use a custom Op that represents something like at.sum(x, axis=a), and this prevents us from reasoning about/with the resulting node as if it were the at.sum that it actually is.
This bug is an unexpected consequence of https://github.com/aesara-devs/aesara/pull/928 and rewrites that make certain assumptions: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1291561804
The faulty logic is found here: https://github.com/aesara-devs/aesara/blob/7f4e0ab443826af7f3f48d77bf13027ab6bdff69/aesara/tensor/elemwise.py#L552-L570
This is also likely a problem in the grad of BroadcastTo, which calls infer_broadcastable and which defaults to assuming something will not have broadcasted if a static shape of 1 can't be inferred.
https://github.com/aesara-devs/aesara/blob/7f8af9bc28755d93dca3afff2534a8a5f5ecbd80/aesara/tensor/extra_ops.py#L1593
https://github.com/aesara-devs/aesara/blob/7f8af9bc28755d93dca3afff2534a8a5f5ecbd80/aesara/tensor/basic.py#L1313
And also GEMM since #986
I am not sure if there's a good solution to this problem, as we would need an expression with different output shapes depending on whether the runtime inputs are broadcasted or not.
Solution might look something like: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1202225638