aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Wrong gradients when inputs are dynamically broadcasted #1089

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

This bug is an unexpected consequence of https://github.com/aesara-devs/aesara/pull/928 and rewrites that make certain assumptions: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1291561804

import aesara
import aesara.tensor as at
import numpy as np

x_row = at.row("x_row")
x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_row = aesara.function([x_row, y], x_row_grad)
print(f_row(np.ones((1, 5)), np.ones((5, 5))))
# [[5. 5. 5. 5. 5.]]

f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
print(f_matrix(np.ones((1, 5)), np.ones((5, 5))))
# [[1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]]
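For reference, the gradient of `sum(x + y)` with respect to `x` starts as an all-ones array with the broadcast output shape. It is then summed over every axis along which `x` was broadcast. The minimal NumPy sketch below uses a hypothetical helper name, `reference_grad_x`, and assumes `x` and `y` have the same rank, as in the example above. It shows that the correct result depends on the runtime shape of `x`, which is exactly what the `x_matrix` graph cannot know at compile time:

```python
import numpy as np


def reference_grad_x(x_val, y_val):
    """Hypothetical NumPy reference for d/dx sum(x + y) (same-rank inputs only)."""
    out_shape = np.broadcast_shapes(x_val.shape, y_val.shape)
    # d sum(out) / d out is all ones, with the broadcast output shape.
    out_grad = np.ones(out_shape)
    # Sum (keeping dims) over the axes where x had length 1 but the output did not.
    axes = tuple(
        i for i, (xs, os) in enumerate(zip(x_val.shape, out_shape)) if xs == 1 and os != 1
    )
    return out_grad.sum(axis=axes, keepdims=True) if axes else out_grad


print(reference_grad_x(np.ones((1, 5)), np.ones((5, 5))))
# [[5. 5. 5. 5. 5.]]
print(reference_grad_x(np.ones((5, 5)), np.ones((5, 5))).shape)
# (5, 5)
```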

The faulty logic is found here: https://github.com/aesara-devs/aesara/blob/7f4e0ab443826af7f3f48d77bf13027ab6bdff69/aesara/tensor/elemwise.py#L552-L570

This is also likely a problem in the grad of BroadcastTo which calls infer_broadcastable and which defaults to assuming something will not have broadcasted if a static shape of 1 can't be inferred.

https://github.com/aesara-devs/aesara/blob/7f8af9bc28755d93dca3afff2534a8a5f5ecbd80/aesara/tensor/extra_ops.py#L1593

https://github.com/aesara-devs/aesara/blob/7f8af9bc28755d93dca3afff2534a8a5f5ecbd80/aesara/tensor/basic.py#L1313

This also affects GEMM since #986.

I am not sure if there's a good solution to this problem, as we would need an expression with different output shapes depending on whether the runtime inputs are broadcasted or not.

Solution might look something like: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1202225638

aseyboldt commented 2 years ago

Uff, this looks bad.

This makes me wish we had dimension objects even more, I think that would solve this problem? But I don't see how we could introduce that as a strict requirement without breaking backward compatibility...

If each dimension had an associated object, and we only allowed broadcasting if those objects are identical, we would always know how broadcasting would work when we create nodes, even before we know whether a dimension happens to have length 1.

So something like:

dim0 = at.Dimension(name="dim0")
dim1 = at.Dimension(name="dim1")
dim2 = at.Dimension(name="dim2")

x = at.dvector("x", dims=(dim0,))
y = at.dvector("y", dims=(dim1,))

x + y  # dims == (dim0, dim1)

x = at.dvector("x", dims=(dim0,))
y = at.dvector("y", dims=(dim0,))

x + y  # dims=(dim0,)

X = at.dmatrix("X", dims=(dim0, dim1))
Y = at.dmatrix("Y", dims=(dim0, dim2))

X + Y  # dims == (dim0, dim1, dim2)

That way we'd always statically know the result shape/dims of broadcasting ops.
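The proposed rule is to broadcast by dimension identity, never by length. It can be sketched in plain Python, with dimensions represented as strings. Both `broadcast_dims` and `grad_sum_axes` are hypothetical helpers, not Aesara APIs. The output dimensions are the ordered union of the input dimensions, as in the examples above. The axes to sum for an input's gradient then follow from the names alone, before any lengths are known:

```python
def broadcast_dims(*dims_per_input):
    """Hypothetical sketch: combine named dimensions by identity.

    Each argument is a tuple of dimension names for one input. Matching
    names align; distinct names become separate output axes, regardless
    of their (runtime) lengths.
    """
    out = []
    for dims in dims_per_input:
        for dim in dims:
            if dim not in out:
                out.append(dim)
    return tuple(out)


def grad_sum_axes(out_dims, in_dims):
    """Output-gradient axes to sum away to get the gradient of one input."""
    return tuple(i for i, dim in enumerate(out_dims) if dim not in in_dims)


print(broadcast_dims(("dim0",), ("dim1",)))  # ('dim0', 'dim1')
print(broadcast_dims(("dim0",), ("dim0",)))  # ('dim0',)
print(broadcast_dims(("dim0", "dim1"), ("dim0", "dim2")))  # ('dim0', 'dim1', 'dim2')
print(grad_sum_axes(("dim0", "dim1", "dim2"), ("dim0", "dim2")))  # (1,)
```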

brandonwillard commented 2 years ago

This makes me wish we had dimension objects even more, I think that would solve this problem?

It looks like the dimension objects you're talking about would provide essentially the same information/accomplish the same things as specify_shape (alongside the existing shape inference). Perhaps the main difference is that these dimension objects are implicitly intended to work at the Type level? If so, the advantages of that approach aren't apparent at the moment.

aseyboldt commented 2 years ago

No, it is a pretty different idea than specify_shape. specify_shape just fixes the shape to known integer values, it provides no information if two different arrays that happen to have the same shape actually refer to the same dimension.

By dimension I mean some constant thing that has a clear identity and might for instance also have associated coordinates.

So if we have dimensions time and country, and x.dims == (time,); y.dims == (country,), those two axes could never broadcast together, even if by accident we had the same number of time points as countries, because their respective dimensions refer to different, incompatible things.

And if we have two arrays with the same dimension, we still wouldn't know the shape, only that those arrays share the same, conceptually compatible dimension. A type like Array<dims=('country',)> would mean that we are providing one value for each country (whatever number of countries we have). It never makes sense to directly add this to Array<dims=('time',)>, i.e. "one value for each point in time", even if n_time == n_countries. So, for instance, we can redefine the sum as a broadcasting sum with output Array<dims=('time', 'country')>.

This is essentially how broadcasting works in xarray. In my opinion it is way cleaner, prevents lots of bugs and is in general a joy to work with, compared with the relatively messy NumPy behavior, where, for instance, all hell can break loose if one of your dimensions just happens to have length 1. It does require a bit of acclimatization though. :-)

Edit: But unless we can find a way to solve the compatibility problems, I'm not sure this is the right place to discuss this.

brandonwillard commented 2 years ago

specify_shape just fixes the shape to known integer values, it provides no information if two different arrays that happen to have the same shape actually refer to the same dimension.

specify_shape can be used to assign the same non-constant Variable to arbitrary dimensions in the shapes of multiple Variables. Combine this with shape inference and it's possible to determine that distinct Variables share exactly the same shapes in those dimensions fixed by specify_shape.

import aesara
import aesara.tensor as at
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic_opt import ShapeFeature

shape_dim_1 = at.lscalar("time")
shape_dim_2 = at.lscalar("country")

x = at.vector("x")
x = at.specify_shape(x, (shape_dim_1,))

y = at.matrix("y")
y = at.specify_shape(y, (shape_dim_1, shape_dim_2))

shape_feature = ShapeFeature()
fg = FunctionGraph(outputs=[x.shape[0], y.shape[0]], features=[shape_feature])

aesara.dprint(fg)
# Subtensor{int64} [id A] 2
#  |Shape [id B] 1
#  | |SpecifyShape [id C] 0
#  |   |x [id D]
#  |   |time [id E]
#  |ScalarConstant{0} [id F]
# Subtensor{int64} [id G] 5
#  |Shape [id H] 4
#  | |SpecifyShape [id I] 3
#  |   |y [id J]
#  |   |time [id E]
#  |   |country [id K]
#  |ScalarConstant{0} [id L]

fg_opt = optimize_graph(fg)

aesara.dprint(fg_opt)
# time [id A]
# time [id A]

shape_feature.shape_of
# {x: (Shape_i{0}.0,),
#  time: (),
#  SpecifyShape.0: (time,),
#  Shape.0: (TensorConstant{1},),
#  ScalarConstant{0}: (),
#  Subtensor{int64}.0: (),
#  y: (Shape_i{0}.0, Shape_i{1}.0),
#  country: (),
#  SpecifyShape.0: (time, country),
#  Shape.0: (TensorConstant{2},),
#  ScalarConstant{0}: (),
#  Subtensor{int64}.0: (),
#  MakeVector{dtype='int64'}.0: (TensorConstant{2},),
#  InplaceDimShuffle{}.0: (),
#  MakeVector{dtype='int64'}.0: (TensorConstant{1},)}

The main distinction I see here is the desire to refer to dimensions independently from their shapes, but that's not an entirely meaningful distinction in this symbolic context, because a symbolic indexed shape can serve as a direct proxy for a dimension.

aseyboldt commented 2 years ago

Oh, that's neat, I didn't realize you could put in variables! I'll have to play with this and see where I can get with it.

The different broadcasting behavior still does seem like a different issue though, doesn't it?

brandonwillard commented 2 years ago

Oh, that's neat, I didn't realize you could put in variables! I'll have to play with this and see where I can get with it.

There are a lot of places where we could/should be using shape inference, but we don't, so, while this stuff is possible, it's not necessarily used/as integrated as one would expect.

The different broadcasting behavior still does seem like a different issue though, doesn't it?

I'm not entirely clear on this situation yet; I only mentioned specify_shape and shape inference because, when combined, they should—theoretically—allow us to determine (non-)equivalent shapes when static/Type-level information isn't available.

ricardoV94 commented 2 years ago

It seems like we might need a new Op that unbroadcasts (reduces) arrays to a given shape or leaves the input unchanged.

Check https://mostafa-samir.github.io/auto-diff-pt2/#unbroadcasting-adjoints

https://github.com/Mostafa-Samir/Hands-on-Intro-to-Auto-Diff/blob/29a9e5157421e2603846a15bceff21d3b2104f3d/autodiff/grads.py#L103

Something like:

x = at.matrix("x")

# at.reduce_to is probably a better name
# maybe sum is all we will ever need
y1 = at.unbroadcast_to(x, shape=(1, 5), reduce_op="sum")
y1.eval({x: np.ones((5, 5))})  # [[5, 5, 5, 5, 5]]

# This won't do anything, but shape may be only 
# known at runtime, as in the example in this issue!
y2 = at.unbroadcast_to(x, shape=(5, 5))
y2.eval({x: np.ones((5, 5))})  # np.ones((5, 5))

# If the shape is not compatible with something that could 
# have been broadcasted to the input shape, an error is raised
y3 = at.unbroadcast_to(x, shape=(2, 5))
y3.eval({x: np.ones((5, 5))})  # ValueError

This was also brought up by @sayam753 and @purna135 on Slack in relation to their work on batched solve, where dynamic unbroadcasting gradients also crop up. It was that discussion that led me to suspect this bug!

Edit: This may be possible already without a specialized Op, if sum allows for symbolic axes. Does it?

In that case we could cook up a helper pretty quickly, and perhaps add a rewrite for when the axes are constant-folded during compilation, since a sum with constant axes is more efficient.

Edit: Sum does not allow for variable axes.
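The semantics proposed above for `unbroadcast_to` (sum over broadcast axes, no-op when shapes match, error when the target shape could not have been broadcast) can be written as a NumPy reference. This is a minimal sketch: `unbroadcast_to` is the hypothetical name from the proposal, only `sum` is supported as the reduction, and only same-rank inputs are handled, as in the Elemwise case:

```python
import numpy as np


def unbroadcast_to(x, shape):
    """Hypothetical NumPy reference for the proposed `unbroadcast_to`.

    Sums `x` (keeping dims) over every axis where the target `shape` has
    length 1 but `x` does not. Only same-rank inputs are handled here.
    """
    shape = tuple(shape)
    if len(shape) != x.ndim:
        raise ValueError(f"target shape {shape} must have {x.ndim} dims")
    for s, xs in zip(shape, x.shape):
        # A target length may only differ from x's length if it is 1,
        # i.e. if broadcasting could have stretched it to x's length.
        if s != xs and s != 1:
            raise ValueError(f"{shape} could not have been broadcast to {x.shape}")
    axes = tuple(i for i, (s, xs) in enumerate(zip(shape, x.shape)) if s == 1 and xs != 1)
    # When no axis needs reducing (e.g. shape == x.shape), return x unchanged.
    return np.sum(x, axis=axes, keepdims=True) if axes else x


print(unbroadcast_to(np.ones((5, 5)), (1, 5)))  # [[5. 5. 5. 5. 5.]]
print(unbroadcast_to(np.ones((5, 5)), (5, 5)).shape)  # (5, 5)
```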

ricardoV94 commented 2 years ago

Edit: mentioned other related issues in the top comment.

brandonwillard commented 2 years ago

Let's try to move to the Blockwise form of this problem/situation. That way, we can attempt to make some progress on https://github.com/aesara-devs/aesara/issues/695 (e.g. finish implementing Blockwise.Lop in https://github.com/aesara-devs/aesara/pull/757) and address these issues (via inheritance from the Elemwise case).

brandonwillard commented 2 years ago

This is also likely a problem in the grad of BroadcastTo which calls infer_broadcastable and which defaults to assuming something will not have broadcasted if a static shape of 1 can't be inferred.

Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.

aseyboldt commented 2 years ago

Maybe it would help if we added an op that precisely computes what actually happens when we broadcast several arrays?

So a wrapper around something like this (I really hope this can be simplified a bit, though). For known shapes this could be constant-folded.

def broadcast_patterns_numpy(*shapes):
    """Given input shapes predict broadcasting behavior.

    Returns
    -------
    - output_shape: List[int]
      The shape of the broadcasted array
    - did_broadcast: List[List[bool]]
      For each input and each output dimension: True, if the input
      was broadcasted in this particular axis
    - dim_index: List[List[Optional[int]]]
      For each input and each output dimension: An index that indicates
      which axis of the input array corresponds to the output axis. If
      the input did not have an axis, None
    """
    out_rank = max(len(shape) for shape in shapes)
    shapes_rev = [shape[::-1] for shape in shapes]

    # by_rank[i] contains the lengths of the
    # input dimensions in position i, counted
    # from the back
    by_rank = [[] for _ in range(out_rank)]
    for shape_rev in shapes_rev:
        i = -1
        for i, length in enumerate(shape_rev):
            by_rank[i].append(length)
        for j in range(i + 1, out_rank):
            by_rank[j].append(None)

    does_broadcast_rev = [[] for _ in shapes]
    dim_index_rev = [[] for _ in shapes]
    output_shape_rev = []
    for rank, lengths in enumerate(by_rank):
        lengths_non_zero = [length for length in lengths if length is not None and length != 1]
        if not lengths_non_zero:
            out_length = 1
        else:
            out_length = lengths_non_zero[0]
        for length in lengths_non_zero[1:]:
            if length != out_length:
                raise ValueError(f"Could not broadcast at rev rank {rank}: {lengths}")

        output_shape_rev.append(out_length)

        for input, length in enumerate(lengths):
            if length is None:
                does_broadcast_rev[input].append(True)
                dim_index_rev[input].append(None)
            else:
                does_broadcast_rev[input].append(length == 1 and out_length != 1)
                n_before = sum(idx is not None for idx in dim_index_rev[input])
                dim_index_rev[input].append(len(shapes[input]) - n_before - 1)

    does_broadcast = [bools[::-1] for bools in does_broadcast_rev]
    dim_index = [idxs[::-1] for idxs in dim_index_rev]
    output_shape = output_shape_rev[::-1]

    return output_shape, does_broadcast, dim_index

For example if we broadcast arrays of shape (), (2, 5, 0) and (1, 5, 1) we'd get

out_shape, did_broadcast, dim_index = broadcast_patterns_numpy((), (2, 5, 0), (1, 5, 1))

out_shape == [2, 5, 0]
did_broadcast == [
    [True, True, True],   # The first argument was broadcast in all axes
    [False, False, False],  # The second argument wasn't broadcasted at all
    [True, False, True]   # The third argument was broadcasted in the first and last axis
]
dim_index == [
    [None, None, None],   # New axes were created for all output axes for the first arg
    [0, 1, 2],   # For the second arg all axes were used and their order didn't change
    [0, 1, 2]
]

The dim_index is a little redundant in numpy broadcasting, but it would be really helpful in xarray broadcasting. ;-)

brandonwillard commented 2 years ago

So a wrapper around something like this (I really hope this can be simplified a bit, though). For known shapes this could be constant-folded.

The function I linked to in my previous comment already computes broadcast shapes.

ricardoV94 commented 2 years ago

For example if we broadcast arrays of shape (), (2, 5, 0) and (1, 5, 1) we'd get

What's up with that shape of (2, 5, 0)?

I think that combo is invalid as per numpy broadcasting rules due to the zero?

Maybe it would help if we added an op that precisely computes what actually happens when we broadcast several arrays?

That sounds valid, but it seems a more convoluted answer if you ONLY want to fix this problem.

In these cases we have the input and broadcasted gradient output, so the only thing we need is to reduce that gradient along the dimensions that were of size 1 in the input.

Actually, in the Elemwise case we don't even have to worry about new dims, because make_node adds Dimshuffles to the inputs to align the number of dims (but we will perhaps remove that one day)

So what we need is just:

# perform method of new Op boils down to this
def unbroadcast_to(x, shape):
  axis_to_sum = [
    i
    for i, (s, xs) in enumerate(zip(shape, x.shape))
    if s==1 and xs !=1
  ]

  if not axis_to_sum:
    return x
  return np.sum(x, axis=axis_to_sum, keepdims=True)

In the grad where this issue would crop up we would do something like

grad = unbroadcast_to(bcast_grad, shape=input.shape)

And then we could have some rewrites to try to get rid of this Op during compilation. For instance:

1) If the target shape and the input shape are known to be equivalent, we can remove the Op.

2) If the target shape has some (constant-folded) 1s, we can replace the original input with one in which the known-1 dimensions have already been summed. Hopefully there's already a rewrite to get rid of useless sums along dimensions of size 1; if not, we can add one.

3) If the target shape has no (constant-folded) 1s, we can remove the Op (or replace it with some shape-related Asserts if we want to).
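To make the three rewrite cases concrete, here is a plain-Python sketch of the decision such a compile-time rewrite could make. It makes several assumptions:

- Each dimension is either a known integer or `None` (unknown until runtime).
- Only constants are compared, so symbolic equivalence via shape inference (case 1 in general) is out of scope.
- Case 3 is read as applying to dimensions whose target length is a known constant other than 1.

`plan_unbroadcast_rewrite` is a hypothetical name:

```python
from typing import Optional, Sequence, Tuple


def plan_unbroadcast_rewrite(
    target_shape: Sequence[Optional[int]],
    input_shape: Sequence[Optional[int]],
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Hypothetical sketch: split axes into statically summable and runtime ones.

    `None` marks a dimension whose length is only known at runtime.
    """
    static_sum_axes = []
    dynamic_axes = []
    for axis, (t, s) in enumerate(zip(target_shape, input_shape)):
        if t is not None and s is not None and t == s:
            # Case 1: lengths known to be equal, nothing to reduce here.
            continue
        if t == 1:
            # Case 2: a known target length of 1 can be summed statically;
            # this is harmless even if the input length also turns out to be 1.
            static_sum_axes.append(axis)
        elif t is None:
            # Unknown target length: only a runtime check can decide.
            dynamic_axes.append(axis)
        # Otherwise (case 3) the target is a known length != 1, so the input
        # cannot have been broadcast along this axis and no sum is needed.
    return tuple(static_sum_axes), tuple(dynamic_axes)


print(plan_unbroadcast_rewrite((1, 5), (5, 5)))  # ((0,), ())
print(plan_unbroadcast_rewrite((None, 5), (5, 5)))  # ((), (0,))
```

If both tuples are empty, the Op can be dropped. If only static axes remain, it can be replaced by a plain `sum(..., keepdims=True)`. Otherwise the runtime Op (or branch) is still needed for the dynamic axes.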

ricardoV94 commented 2 years ago

Yeah, it looks like we might need to add symbolic conditions for those dimensions and let them be simplified later via shape inference and/or constant folding. This is similar to what we do when constructing broadcasted shape graphs.

The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time if something will have had a shape of 1 (or certainly not 1) is to raise in the grad method.

If IfElse allowed for different shapes in the two branches we could also write a symbolic graph that applies the needed logic, but from one of the open issues it seems that both branches must have the same shape.

ricardoV94 commented 2 years ago

Allowing sum to have symbolic axes (as long as keepdims is used, this should be fine for Aesara) would also allow for a simple solution without new Ops. But maybe that would raise a whole new set of problems.

brandonwillard commented 2 years ago

The problem is that without an Op like the one I sketched, the only safe thing to do when you can't be sure ahead of time if something will have had a shape of 1 (or certainly not 1) is to raise in the grad method.

Simply put, if we don't have the information at compile time, then it needs to be handled at run-time.
The latter is the only scenario in which an Op would be helpful; however, I have yet to see why a new Op is necessary for anything in this scenario. The reason(s) why a new Op is completely necessary also need to be very clear in order to merge anything that takes such an approach.

brandonwillard commented 2 years ago

Additionally, the description of this issue needs to clarify which result is correct and why.

ricardoV94 commented 2 years ago

Additionally, the description of this issue needs to clarify which result is correct and why.

The gradients should have the same shape as the inputs, so the result of the row case ([[5. 5. 5. 5. 5.]]) is correct.

This issue will arise for any Op that may or may not broadcast its inputs at runtime. If broadcasting occurs, you need to sum the gradient across the broadcasted dimensions; otherwise you should not. However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.

ricardoV94 commented 2 years ago

Note that explicitly broadcasting all the inputs (like the explicit Dimshuffles introduced by Elemwise) wouldn't fix this either. The gradient of BroadcastTo shares the same limitations of Elemwise.

brandonwillard commented 2 years ago

However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.

It does.

ricardoV94 commented 2 years ago

However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.

It does.

Via what?

ricardoV94 commented 2 years ago

To be clear, we need something that can do the following.

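import aesara
import aesara.tensor as at
import numpy as np


def foo(x, y):
  ...


x = at.matrix("x")
y = np.random.normal(size=(5, 5))
f = aesara.function([x], foo(x, y))

# assert_allclose also checks that the shapes match
np.testing.assert_allclose(f(np.ones((1, 5))), np.sum(y, axis=0, keepdims=True))
np.testing.assert_allclose(f(np.ones((5, 1))), np.sum(y, axis=1, keepdims=True))
np.testing.assert_allclose(f(np.ones((1, 1))), np.sum(y, axis=(0, 1), keepdims=True))
np.testing.assert_allclose(f(np.ones((5, 5))), y)
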
aseyboldt commented 2 years ago

About the shape 0: NumPy allows that actually (although I think that might be a design flaw and I think this complicates things a bit, but I haven't thought it through).
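NumPy's own shape helper confirms this: a length-1 axis broadcasts against a length-0 axis, and the result is empty. The sketch assumes NumPy 1.20 or later for `np.broadcast_shapes`. It also shows what summing an empty output gradient back to the length-1 input gives:

```python
import numpy as np

# A length-1 axis broadcasts against a length-0 axis: the result is empty.
print(np.broadcast_shapes((1,), (0,)))  # (0,)
print(np.broadcast_shapes((), (2, 5, 0), (1, 5, 1)))  # (2, 5, 0)

# Reducing an empty output gradient back to a length-1 input
# gives zeros with that input's shape.
out_grad = np.ones((0,))
print(out_grad.sum(axis=0, keepdims=True))  # [0.]
```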


In these cases we have the input and broadcasted gradient output, so the only thing we need is to reduce that gradient along the dimensions that were of size 1 in the input.

Is this actually enough? We still need to know if keepdims=True or keepdims=False, right? But I guess we can use the rank of the arrays for that? I actually thought at first we'd also need to know if broadcasting did in fact occur, but I guess an extra sum(keepdims=True) doesn't hurt if the dimension really has length 1.

ricardoV94 commented 2 years ago

About the shape 0: Numpy allows that actually (although I think that might be a design flaw and I think this complicates things a bit, but I haven't thought it through):

Did not expect that! I don't think that's differentiable anyway :P

Is this actually enough? We still need to know if keepdims=True or keepdims=False, right? But I guess we can use the rank of the arrays for that?

Yeah, extra dims are easy to know. In the Elemwise case, for now, it's always a case of keepdims=True, because Elemwise automatically adds new dims to the inputs in make_node via DimShuffles, so the loss of dims is handled by the grad of DimShuffle. But we can try to make this more useful and also account for those cases.

I actually thought at first we'd also need to know if broadcasting did in fact occur, but I guess an extra sum(keepdims=True) doesn't hurt if the dimension really has length 1.

Yeah I assumed that wouldn't hurt, but we can be more clever about it if needed.

aseyboldt commented 2 years ago

Did not expect that! I don't think that's differentiable anyway :P

Yes it is, the derivative is just an empty array. :-)

I wrote a couple of tests; maybe we can convert those into pytest.mark.parametrize tests (and maybe check the grad numerically):

```python
import numpy as np
import aesara.tensor as at
import aesara


def broadcast_patterns_numpy(*shapes):
    """Given input shapes predict broadcasting behavior.

    Returns
    -------
    - output_shape: List[int]
      The shape of the broadcasted array
    - did_broadcast: List[List[bool]]
      For each input and each output dimension: True, if the input
      was broadcasted in this particular axis
    - dim_index: List[List[Optional[int]]]
      For each input and each output dimension: An index that indicates
      which axis of the input array corresponds to the output axis. If
      the input did not have an axis, None
    """
    out_rank = max(len(shape) for shape in shapes)
    shapes_rev = [shape[::-1] for shape in shapes]

    # by_rank[i] contains the lengths of the
    # input dimensions in position i, counted
    # from the back
    by_rank = [[] for _ in range(out_rank)]
    for shape_rev in shapes_rev:
        i = -1
        for i, length in enumerate(shape_rev):
            by_rank[i].append(length)
        for j in range(i + 1, out_rank):
            by_rank[j].append(None)

    does_broadcast_rev = [[] for _ in shapes]
    dim_index_rev = [[] for _ in shapes]
    output_shape_rev = []
    for rank, lengths in enumerate(by_rank):
        lengths_non_zero = [length for length in lengths if length is not None and length != 1]
        if len(lengths_non_zero) == 0:
            out_length = 1
        else:
            out_length = lengths_non_zero[0]
        for length in lengths_non_zero[1:]:
            if length != out_length:
                raise ValueError(f"Could not broadcast at rev rank {rank}: {lengths}")

        output_shape_rev.append(out_length)

        for input, length in enumerate(lengths):
            if length is None:
                does_broadcast_rev[input].append(True)
                dim_index_rev[input].append(None)
            else:
                does_broadcast_rev[input].append(length == 1 and out_length != 1)
                n_before = sum(idx is not None for idx in dim_index_rev[input])
                dim_index_rev[input].append(len(shapes[input]) - n_before - 1)

    does_broadcast = [bools[::-1] for bools in does_broadcast_rev]
    dim_index = [idxs[::-1] for idxs in dim_index_rev]
    output_shape = output_shape_rev[::-1]

    return output_shape, does_broadcast, dim_index


def test_broadcasting_grad(known_shapes, input_shapes, dtypes, out_shape, out_dtype):
    inputs = [
        at.tensor(shape=shape if known_shapes else [None for _ in shape], dtype=dtype)
        for shape, dtype in zip(input_shapes, dtypes)
    ]
    out = sum(inputs)
    vals = [
        np.zeros(shape=shape, dtype=dtype)
        for shape, dtype in zip(input_shapes, dtypes)
    ]

    assert np.dtype(out.type.dtype) == np.dtype(out_dtype)

    out_shape_, did_broadcast, dim_index = broadcast_patterns_numpy(*input_shapes)
    assert tuple(out_shape_) == tuple(out_shape), f"incorrect output shape. Is {out_shape} but should be {out_shape_}"

    out_grad_val = np.zeros(out_shape_, dtype=out_dtype)
    # Make sure sums are unique
    out_grad_val.ravel()[...] = 2 ** np.arange(len(out_grad_val.ravel()))
    out_grad = at.tensor(shape=out_shape, dtype=out_dtype)

    grads = aesara.Lop(out, inputs, out_grad)
    func = aesara.function([out_grad] + inputs, grads, on_unused_input="ignore")
    grad_vals = func(out_grad_val, *vals)

    for i, (grad, dtype, shape, did_bc, index) in enumerate(zip(grad_vals, dtypes, input_shapes, did_broadcast, dim_index)):
        assert tuple(grad.shape) == tuple(shape), f"Grad {i} should be {shape} but is {grad.shape}"
        assert np.dtype(grad.dtype) == np.dtype(dtype)

        do_sum = np.array(did_bc).nonzero()[0]
        expected_grad = out_grad_val
        if len(do_sum) > 0:
            expected_grad = out_grad_val.sum(axis=tuple(do_sum), keepdims=True)
        remove_dim = [idx is None for idx in index]
        do_sum = np.array(remove_dim).nonzero()[0]
        if len(do_sum) > 0:
            expected_grad = expected_grad.sum(axis=tuple(do_sum), keepdims=False)
        assert tuple(expected_grad.shape) == tuple(grad.shape), f"Grad {i} should be {expected_grad.shape} but is {grad.shape}"
        assert np.allclose(expected_grad, grad), "Incorrect vals"


# NOTE: the three-input cases below list only two dtypes, so the
# zip(input_shapes, dtypes) calls in the test silently drop the third input.
cases = [
    ([(2,), (2,)], [np.float64, np.float64], (2,), np.float64),
    ([(1,), (2,)], [np.float64, np.float64], (2,), np.float64),
    ([(0,), (1,)], [np.float64, np.float64], (0,), np.float64),
    ([(1,), ()], [np.float64, np.float64], (1,), np.float64),
    ([(2,), ()], [np.float64, np.float64], (2,), np.float64),
    ([(), (2,)], [np.float64, np.float64], (2,), np.float64),
    ([(2, 1), (2, 2)], [np.float64, np.float64], (2, 2), np.float64),
    ([(2, 3), (2, 3)], [np.float32, np.float64], (2, 3), np.float64),
    ([(2, 3), (2, 3)], [np.float32, np.float32], (2, 3), np.float32),
    ([(2, 3), ()], [np.float32, np.float32], (2, 3), np.float32),
    ([(2, 3, 4, 1), (1, 3, 1, 5), (5,)], [np.float64, np.float64], (2, 3, 4, 5), np.float64),
    ([(0, 3, 4, 1), (1, 3, 1, 5), (5,)], [np.float64, np.float64], (0, 3, 4, 5), np.float64),
    ([(2, 3), (), (1, 1, 3)], [np.float32, np.float32], (1, 2, 3), np.float32),
]

for case in cases:
    try:
        test_broadcasting_grad(True, *case)
        test_broadcasting_grad(False, *case)
    except Exception as e:
        print(case[0], str(e))
```

Currently, I get test failures for those input shapes:

[(1,), (2,)] Grad 0 should be (1,) but is (2,)
[(0,), (1,)] Grad 1 should be (1,) but is (0,)
[(2, 1), (2, 2)] Grad 0 should be (2, 1) but is (2, 2)
[(2, 3, 4, 1), (1, 3, 1, 5), (5,)] Grad 0 should be (2, 3, 4, 1) but is (2, 3, 4, 5)
[(0, 3, 4, 1), (1, 3, 1, 5), (5,)] Grad 0 should be (0, 3, 4, 1) but is (0, 3, 4, 5)
[(2, 3), (), (1, 1, 3)] Elemwise{add,no_inplace}.grad returned a term with 3 dimensions, but 2 are required.
aseyboldt commented 2 years ago

Found another error message:

    (
        [(1, 1, 1), (), (2, 1)],
        [np.float32, np.float32],
        (1, 2, 1),
        np.float32,
    ),
[(1, 1, 1), (), (2, 1)] Cannot drop a non-broadcastable dimension: (True, False, True), []
ricardoV94 commented 2 years ago

Do these errors come from known_shapes=False?

aseyboldt commented 2 years ago

Most do. However, these also happen with known_shapes=True:

[(2, 3), (), (1, 1, 3)] Elemwise{add,no_inplace}.grad returned a term with 3 dimensions, but 2 are required.
[(1, 1, 1), (), (2, 1)] Cannot drop a non-broadcastable dimension: (True, False, True), []
[(1, 1), (), (2,)] Cannot drop a non-broadcastable dimension: (True, False), []
ricardoV94 commented 2 years ago

The "cannot drop non-broadcastable dimension" sounds like another legacy of treating unknown dims as != 1, but I am surprised that it shows up here.

I am not sure about the "expected 2 but got 3" one.

aseyboldt commented 2 years ago

@ricardoV94 You mean something like this with broadcast_to?

import aesara
import aesara.tensor as at
import numpy as np
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class SumToShape(Op):
    __props__ = ()

    def make_node(self, tensor, target_shape):
        dtype = tensor.type.dtype
        target_shape = at.as_tensor_variable(target_shape)
        # output = type(tensor.type)(dtype=dtype, shape=target_shape)()   # Why doesn't this work?
        output = type(tensor.type)(dtype=dtype, shape=[None for _ in target_shape])()
        return Apply(self, [tensor, target_shape], [output])

    def perform(self, node, inputs, outputs):
        (tensor, target_shape) = inputs
        (output,) = outputs

        # Extra leading dims of the input are summed away entirely
        leading_dims = len(tensor.shape) - len(target_shape)
        assert leading_dims >= 0

        sum_axis = [True] * leading_dims
        for actual_length, target_length in zip(tensor.shape[leading_dims:], target_shape):
            not_equal = actual_length != target_length
            if not_equal:
                # Only a length-1 target dim can have been broadcast
                assert target_length == 1
            sum_axis.append(not_equal)
        sum_axis = np.nonzero(sum_axis)[0]

        output_value = np.sum(tensor, axis=tuple(sum_axis), keepdims=True)
        # Drop the (now length-1) leading dims
        output_value = output_value[(0,) * leading_dims]
        output[0] = output_value


_SumToShape = SumToShape()


def sum_to_shape(value, shape):
    return _SumToShape(value, shape)


x = at.tensor(shape=(None, None, None), dtype=np.float32)

sum_to_shape(x, (3, 1)).eval({x: np.ones((4, 3, 4), dtype=np.float32)})

Small sidenote: any idea why I can't set the correct shape in make_node?

ricardoV94 commented 2 years ago

Yeah, something like that.

Small sidenote: any idea why I can't set the correct shape in make_node?

Type shapes must be non-symbolic tuples, but you converted target_shape to a TensorVariable in the line above.

brandonwillard commented 2 years ago

If IfElse allowed for different shapes in the two branches we could also write a symbolic graph that applies the needed logic, but from one of the open issues it seems that both branches must have the same shape.

If you're referring to my comment here, see my follow-up/amendment, which implicitly states that IfElse is sufficient for the branching necessary in this issue.

That said, let's explore an IfElse-related approach before introducing new, specialized Ops.

However Aesara does not provide any building blocks that can achieve this branch logic AFAICT.

It does.

Via what?

I was referring to the branching support required to solve this issue. Again, IfElse supports this, and I think that some clever use of other, existing Ops might accomplish the same thing. Finally, a custom Op that encapsulates the requisite branching logic is always possible, as you folks have been discussing.
In other words, Aesara definitely provides the building blocks for a viable solution to this problem.

aseyboldt commented 2 years ago

So far we've been thinking about whether it would be best to add an Op, or to try and do the required operations with existing Ops (e.g. IfElse). But maybe there is a third option that might have nice side effects:

We could also extend the functionality of the Elemwise op a bit, to something like ElemwiseSum. I've been thinking about this anyway, because we get a lot of sum(elemwise) in logp graphs, and it is usually a waste to first store the result, only to add the results up later. Doing both at the same time often speeds up reductions quite a bit.

But wouldn't this allow us to nicely represent what we need for the gradient Ops as well, if Elemwise had an additional parameter sum_outputs_to_shape or so, which would add logic similar to the SumToShape above? In a way this pretty much pushes the problem down to the backends, and maybe that is not what we want. On the other hand, maybe the backend is actually the place where we can handle this most easily and get good performance?

For numba I've played with ElemwiseSum implementations that work something like this:

import numba
import numpy as np


@numba.njit(fastmath=True)  # fastmath helps a lot here, so that llvm can optimize the reduction
def elemwise_sum(x, y):
    # Handle input variable broadcasting; unfortunately this isn't happening in
    # the numba impl of `np.nditer` currently, which prevents some SIMD optimizations.
    x, y = np.broadcast_arrays(x, y)
    # `output_type` and `inner_func` are placeholders for the output scalar
    # type and the scalar function of the Elemwise being fused.
    total = output_type(0.)
    for x_val, y_val in np.nditer((x, y)):
        total += inner_func(x_val.item(), y_val.item())
    return total

If we can find a way to generalize this to partial reductions, then we could at the same time and with the same code improve performance of many sum(elemwise) and produce code for the gradient of elemwise.

brandonwillard commented 2 years ago

We could also extend the functionality of the Elemwise op a bit, to something like ElemwiseSum. I've been thinking about this anyway, because we get a lot of sum(elemwise) in logp graphs, and it is usually a waste to first store the result, only to add the results up later. Doing both at the same time often speeds up reductions quite a bit.

See aesara.tensor.elemwise.CAReduce.

If we can find a way to generalize this to partial reductions, then we could at the same time and with the same code improve performance of many sum(elemwise) and produce code for the gradient of elemwise.

See https://github.com/aesara-devs/aesara/pull/599 and/or the current Numba implementations for CAReduce Ops.

ricardoV94 commented 2 years ago

Here is an implementation with IfElse. I am quite worried about performance, especially since, without the ability to specify a Type which is not broadcastable (shape != 1), the IfElse will have to remain in most use cases:

import aesara
import aesara.tensor as at
from aesara.ifelse import ifelse
import numpy as np

def unbroadcast_to(value, shape):
    for i, (s, vs) in enumerate(zip(shape, value.shape)):
        # To increase the odds of optimizing it away, it is perhaps better
        # to use the condition `at.neq(s, vs)`,
        # or to allow for more useless sums
        # and only check `at.eq(s, 1)`
        value = ifelse(
            at.and_(at.eq(s, 1), at.neq(vs, 1)), 
            at.sum(value, axis=i, keepdims=True), 
            value,
        )
    return value

x = at.matrix("x")
y = at.tensor("float64", shape=(5, 5), name="y")
f = aesara.function([x, y], unbroadcast_to(y, x.shape))

y_val = np.random.normal(size=(5, 5))

assert np.allclose(f(np.ones((1, 5)), y_val), np.sum(y_val, axis=0, keepdims=True))
assert np.allclose(f(np.ones((5, 1)), y_val), np.sum(y_val, axis=1, keepdims=True))
assert np.allclose(f(np.ones((1, 1)), y_val), np.sum(y_val, axis=(0, 1), keepdims=True))
assert np.allclose(f(np.ones((5, 5)), y_val), y_val)

print(aesara.dprint(f))
if{inplace} [id A] 10
 |Elemwise{Composite{AND(EQ(i0, i1), NEQ(i2, i1))}} [id B] 9
 | |Shape_i{1} [id C] 8
 | | |x [id D]
 | |TensorConstant{1} [id E]
 | |if{shape,inplace}.1 [id F] 7
 |   |Elemwise{eq,no_inplace} [id G] 3
 |   | |Shape_i{0} [id H] 2
 |   | | |x [id D]
 |   | |TensorConstant{1} [id E]
 |   |TensorConstant{1} [id I]
 |   |TensorConstant{5} [id J]
 |   |TensorConstant{5} [id J]
 |   |TensorConstant{5} [id J]
 |InplaceDimShuffle{0,x} [id K] 6
 | |Sum{axis=[1], acc_dtype=float64} [id L] 5
 |   |if{inplace} [id M] 4
 |     |Elemwise{eq,no_inplace} [id G] 3
 |     |InplaceDimShuffle{x,0} [id N] 1
 |     | |Sum{axis=[0], acc_dtype=float64} [id O] 0
 |     |   |y [id P]
 |     |y [id P]
 |if{inplace} [id M] 4
brandonwillard commented 2 years ago

I am quite worried about performance

Why exactly are you worried about this? As far as I can tell, it's part of a run-time cost that we must necessarily pay in order to get the correct results in all cases.

Specially without the ability to specify a Type which is not broadcastable (shape != 1), the IfElse will have to remain in most use-cases:

Don't forget: something/someone needs to specify those shape != 1 constraints, so such optimizations are fundamentally limited by user-level specificity in many cases.

We can—and will—focus on rewriting these graphs so that they're more efficient, though.

brandonwillard commented 2 years ago

Also, when implementing a function like unbroadcast_to, we will most likely need to use aesara.scalar Ops explicitly. That will help us avoid unnecessary circular Elemwise dependencies, among other things. (This might only be relevant to related issues within Elemwise, if any, and other Op.infer_shape implementations.)

N.B. the name unbroadcast_to is a bit confusing, especially given the similarly named TensorType-affecting Ops and functions. Something like sum_broadcasted_axes is more descriptive/accurate.
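Using the name suggested here, the run-time decision that the IfElse graph above encodes can be written as a plain-NumPy function. This is only a sketch for checking expected values: in a graph, each `if` would become an `ifelse` node, or disappear when the shapes are known statically.

```python
import numpy as np


def sum_broadcasted_axes(value, shape):
    """Sum `value` over the axes along which an input of `shape` was broadcast.

    Mirrors the per-axis `ifelse` in `unbroadcast_to`, but evaluates the
    condition eagerly on concrete shapes instead of inside the graph.
    """
    if value.ndim != len(shape):
        raise ValueError("value and shape must have the same number of dimensions")
    for i, (s, vs) in enumerate(zip(shape, value.shape)):
        # Only a length-1 input axis facing a non-length-1 output axis was broadcast
        if s == 1 and vs != 1:
            value = value.sum(axis=i, keepdims=True)  # keepdims keeps axis indices valid
    return value


y_val = np.arange(25.0).reshape(5, 5)
assert np.allclose(sum_broadcasted_axes(y_val, (1, 5)), y_val.sum(axis=0, keepdims=True))
assert np.allclose(sum_broadcasted_axes(y_val, (5, 1)), y_val.sum(axis=1, keepdims=True))
assert np.allclose(sum_broadcasted_axes(y_val, (1, 1)), y_val.sum(keepdims=True))
assert np.allclose(sum_broadcasted_axes(y_val, (5, 5)), y_val)
```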

ricardoV94 commented 2 years ago

Don't forget: something/someone needs to specify those shape != 1 constraints, so such optimizations are fundamentally limited by user-level specificity in many cases.

I am starting to realize that's what Theano always requested of users, with the whole broadcast flags business. By default, graphs did not allow for broadcasting, and if users wanted that flexibility they had to provide information about the dims that were to be broadcastable.

However, instead of raising ValueErrors at the Type/input levels, they only raised in the Ops where the distinction would actually matter like Elemwise, so it was not clear (for me at least) that this was actually the design principle. I disagreed with @aseyboldt recently on this.

We went from not allowing broadcasting by default, to always assuming it's possible. We also switched from the Theano distinction of broadcastable (shape=1 vs shape!=1) at the type level to our new distinction where we only allow users to provide one specific type shape or say nothing at all about the type shapes (shape=x vs shape=None).

This gradient issue is where I first see why the Theano design decision might have made sense.

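My concerns are two-fold:

  • I suspect many optimizations will be broken by the need for the IfElse (or a custom Op) in most gradient graphs, and graphs will be considerably slower
  • We won't be able to compile (or at least JIT) gradient graphs to more restricted backends like JAX, unless users specify exact shapes for all the inputs.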

I don't have any benchmarks yet so it's all speculation at this point.

aseyboldt commented 2 years ago

@ricardoV94 Yes, a change from "broadcasting is only allowed if we know for sure it's happening" to "we always first assume you can broadcast and figure it out at runtime" seems like something we should think through properly...

About the ElemwiseSum op, I think an implementation like this would do what we want at least:

import numba
import numpy as np

a = np.random.randn(300, 400)
b = np.random.randn(400)
out = np.zeros(400)

@numba.njit
def elemwise_func(a, b):
    return np.exp(a + b)

@numba.njit(fastmath=True)
def elemwise_sum(a, b, out):
    out[...] = 0
    a, b, out = np.broadcast_arrays(a, b, out)
    for x, y, z in np.nditer((a, b, out)):
        z[()] += elemwise_func(x, y)

# Pure python impl using np.nditer
def elemwise_sum_python(a, b, out):
    out[...] = 0
    it = np.nditer(
        (a, b, out),
        flags=["external_loop", "reduce_ok", "buffered"],
        op_flags=[["readonly"], ["readonly"], ["readwrite"]]
    )
    with it:
        for x, y, z in it:
            z[...] += np.exp(x + y)

Numba doesn't seem to generate great code, however; maybe this is actually due to https://github.com/numba/numba/issues/8314.

I'm also not entirely sure if this is supposed to compile: np.broadcast_arrays usually returns readonly output, but somehow we still end up writing to it in np.nditer (and seem to get the right result). I'm also not sure if numba will specialize the case where the last iteration axis is contiguous, so that it can use SIMD and a vectorized exp...

brandonwillard commented 2 years ago

I am starting to realize that's what Theano always requested of users, with the whole broadcast flags business. By default, graphs did not allow for broadcasting, and if users wanted that flexibility they had to provide information about the dims that were to be broadcastable.

No, that's not what Theano requested, it's what Elemwise demanded; otherwise, Elemwise's interpretation of TensorType.broadcastable would've been used everywhere outside of Elemwise and its rewrites.

However, instead of raising ValueErrors at the Type/input levels, they only raised in the Ops where the distinction would actually matter like Elemwise, so it was not clear (for me at least) that this was actually the design principle. I disagreed with @aseyboldt recently on this.

In other words, that choice moved something from compile-time to run-time. That's not the direction we generally want to go, nor is it one with particularly/any notable advantages.

Also, that wasn't a general Theano design principle, because it doesn't match the Type contract. It was an Elemwise-specific limitation/constraint that sometimes leaked out of scope, and it looks like people just built around it and assumed it when convenient. It was demonstrably not a universally assumed interpretation of TensorType.broadcastable values.

We went from not allowing broadcasting by default, to always assuming it's possible. We also switched from the Theano distinction of broadcastable (shape=1 vs shape!=1) at the type level to our new distinction where we only allow users to provide one specific type shape or say nothing at all about the type shapes (shape=x vs shape=None).

No, we didn't start "always assuming it's possible"; we made it possible, and we're still in the process of doing that.

This gradient issue is where I first see why the Theano design decision might have made sense.

What you're seeing is that things can be much simpler for some Elemwise optimizations under a strict interpretation of TensorType.broadcastable. That's it.

Unfortunately, simplicity for Elemwise does not translate to simplicity in general, and that's why this old assumption ultimately ended up being a big design mistake that led to many subtle bugs and unnecessary restrictions over time.

A lot of this is already covered in https://github.com/aesara-devs/aesara/issues/335 and other, related issues/PRs, but I'll recap the basics in the following, because the amount of confusion surrounding this and the apparently strong appeal of naive reversions is alarming.

Let's start with some very simple and direct proof that Elemwise's strict interpretation is not shared by the rest of Theano:

import numpy as np

import theano
import theano.tensor as tt

theano.__version__
# '1.0.5'

x = tt.row("x")
X = tt.matrix("X")

x.broadcastable
# (True, False)

X.broadcastable
# (False, False)

# When a user says a dimension is broadcastable, it must be,
# so the interpretation of `True` is strict.
x.type.filter_variable(np.ones((2, 3)))
# TypeError: ('Non-unit value on shape on a broadcastable dimension.', (2, 3), (True, False))

x.type.is_valid_value(np.ones((2, 3)))
# False

# Theano obviously doesn't interpret `False` in a strict sense,
# because, if it did, these would necessarily fail:
X.type.filter_variable(np.ones((1, 1)))
# TensorConstant{(1, 1) of 1.0}

X.type.is_valid_value(np.ones((1, 1)))
# True

Let's also consider the documentation for a second. This link clearly states that False TensorType.broadcastable values are not strict; however, this link says "Theano needs to know" the broadcastable dimensions. What's the difference here?

The latter link is (implicitly) talking about Elemwise, which is evident from the examples it uses. Since Elemwise implements broadcasting for all the scalar Ops in Theano, it's easy to mistakenly conclude that all of Theano assumes Elemwise's TensorType.broadcastable interpretation.

Here are some illustrations of Elemwise's strict and inconsistent interpretation of TensorType.broadcastable values:

Y = tt.matrix("Y")

XY_fn = theano.function([X, Y], X * Y)

theano.printing.debugprint(XY_fn, print_type=True)
# Elemwise{mul,no_inplace} [id A] <TensorType(float64, matrix)> ''   0
#  |X [id B] <TensorType(float64, matrix)>
#  |Y [id C] <TensorType(float64, matrix)>

XY_fn(np.ones((2, 2)), np.ones((1, 2)))
# ValueError: Input dimension mis-match. (input[0].shape[0] = 2, input[1].shape[0] = 1)

Everything checks out at the Type-level with those arguments, but the call fails in Op.perform (i.e. run-time), and easily avoidable run-time errors aren't good for anything.

Let's try some other arguments that conflict with the False-means-not-one interpretation:

XY_fn(np.ones((1, 2)), np.ones((1, 2)))
# array([[1., 1.]])

Notice how Elemwise is selectively interpreting a False Type.broadcastable value as an indication that the corresponding shape value is not equal to one, and the check/assumption is only ever effectively applied to dimensions that are actively being broadcasted.
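To make the selective check concrete, here is a plain-Python approximation of the rule as described above. It is a sketch based on this description, not Theano's actual code, and `elemwise_out_shape` is a made-up name. A True flag requires length 1 and permits broadcasting; a False flag requires an exact match with the output length, even when that length happens to be 1.

```python
import numpy as np


def elemwise_out_shape(shapes, broadcastables):
    """Approximate Elemwise's run-time shape check (a sketch, not Theano's code)."""
    out_shape = []
    for d in range(len(shapes[0])):
        for s, bc in zip(shapes, broadcastables):
            if bc[d] and s[d] != 1:
                raise TypeError(f"Non-unit value on a broadcastable dimension {d}")
        # Only non-broadcastable inputs determine the output length
        lengths = {s[d] for s, bc in zip(shapes, broadcastables) if not bc[d]}
        if len(lengths) > 1:
            raise ValueError(f"Input dimension mis-match in dimension {d}: {sorted(lengths)}")
        out_shape.append(lengths.pop() if lengths else 1)
    return tuple(out_shape)


matrix = (False, False)
print(elemwise_out_shape([(1, 2), (1, 2)], [matrix, matrix]))         # (1, 2)
print(elemwise_out_shape([(2, 2), (1, 2)], [matrix, (True, False)]))  # (2, 2), row input
print(np.broadcast_shapes((2, 2), (1, 2)))                            # (2, 2), no flags needed
# elemwise_out_shape([(2, 2), (1, 2)], [matrix, matrix]) raises ValueError
```

The same (1, 2) value is accepted or rejected depending only on whether some other input forces it to broadcast, which is the inconsistency described above.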

If you're thinking "So what?", then ask yourself: do you want ants? Because this is how you get ants. Inconsistently enforced interpretations are only good for confusion and difficult to debug code, and, with this choice, Theano accomplished both.

It might not seem all that bad when one considers Elemwise in isolation, but it becomes very serious when Elemwises are used in non-trivial graphs.

For example, recall this old Theano bug:

a = tt.scalar()
b = tt.isclose(tt.as_tensor([0, 0.5, 1, -1]), a)

b_fn = theano.function([a], b)

b_fn(0.0)
# ValueError: Input dimension mis-match. (input[0].shape[0] = 4, input[3].shape[0] = 1)
a.type.filter_variable(0.0)
# TensorConstant{0.0}

What's wrong in that example? The graph b is not ill-defined, no explicit broadcasting was used, and the inputs are all valid.

Basically, tt.isclose produced a graph that uses Ops and/or rewrites that either don't subscribe to Elemwise's strict interpretation of TensorType.broadcastable, weren't able to perfectly infer the broadcastable properties of their outputs, or any one of Theano's other kludgy workarounds (e.g. use of patternbroadcast and the like) didn't apply.

Here's that graph, in case you're curious:

theano.printing.debugprint(b_fn, print_type=True)
# Elemwise{Composite{AND(LE(Abs((i0 - i1)), i2), Invert(OR(IsNan(i3), IsInf(i3))))}} [id A] <TensorType(bool, vector)> ''   3
#  |TensorConstant{[ 0.   0.5.. 1.  -1. ]} [id B] <TensorType(float64, vector)>
#  |InplaceDimShuffle{x} [id C] <TensorType(float64, (True,))> ''   0
#  | |<TensorType(float64, scalar)> [id D] <TensorType(float64, scalar)>
#  |Elemwise{Composite{(i0 + (i1 * Abs(i2)))}} [id E] <TensorType(float64, (True,))> ''   2
#  | |TensorConstant{(1,) of 1e-08} [id F] <TensorType(float64, (True,))>
#  | |TensorConstant{(1,) of 1e-05} [id G] <TensorType(float64, (True,))>
#  | |InplaceDimShuffle{x} [id C] <TensorType(float64, (True,))> ''   0
#  |Rebroadcast{0} [id H] <TensorType(float64, vector)> ''   1
#    |InplaceDimShuffle{x} [id C] <TensorType(float64, (True,))> ''   0

Why not make all Ops use the strict interpretation? The reason why it probably wasn't done–and why it's still not good to do–is that we effectively cannot do it without disallowing the use of many Ops, rewrites, and ways of constructing graphs, because it requires perfect shape/broadcastable inference. A single mistaken TensorType.broadcastable introduced by one Op.make_node can propagate arbitrarily through a graph and end up breaking when the corresponding Variable is used in an Elemwise.

Remember Elemwise's selective enforcement of its TensorType.broadcastable interpretation? Well, that is part of what allowed Elemwise's overly restrictive interpretation to appear less restrictive than it actually was. Making that interpretation consistent would bring all these new and inherent restrictions to the fore.

By the way, that was not the only such occurrence of this design flaw. These kinds of issues were numerous and, sadly, a regular part of using Theano. It was one source of the ubiquitous "shape issues" that plagued users of Theano and made it a nightmare to debug in non-trivial situations.

The changes made in https://github.com/aesara-devs/aesara/issues/335 and https://github.com/aesara-devs/aesara/pull/928 correct Elemwise's inconsistent and overly strict interpretation of TensorType.broadcastable–essentially broadening its applicability and removing its restrictions. A few other parts of the codebase share Elemwise's restrictions, but–luckily–not the majority of it. From what we've seen, it all emanates from Elemwise, so the issues resulting from those changes are almost exclusively in code that deals explicitly with it (e.g. Elemwise-based rewrites). If that weren't the case, virtually nothing would work in Aesara, but the vast majority of it does.

Unfortunately, the changes introduced by the aforementioned PRs were not the end of this Elemwise refactoring. There are still some things to clean up. (Fortunately, doing so is also helping us identify outstanding gaps in our tests.)

My concerns are two-fold:

  • I suspect many optimizations will be broken by the need for the IfElse (or a custom Op) in most gradient graphs, and graphs will be considerably slower

As long as people don't manually disable the "lazy" evaluation feature of the standard VMs, only one branch of these IfElse nodes will be evaluated. Furthermore, both branch sub-graphs can–and should–be optimized, just as either one would without the IfElse.

Also, there are some simple rewrites we could apply. For instance, with the IfElse fix, some_op(sum(a + b, axis=...)) and some_op(a + b) graphs would become some_op(ifelse(a_b_shape_constraints, a + b, sum(a + b, axis=...))). These IfElse nodes could be pushed down the graph, recovering the original graphs produced by Elemwise.grad (e.g. resulting in ifelse(a_b_shape_constraints, some_op(a + b), some_op(sum(a + b, axis=...))))). Now, all the original rewrites can be applied.
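The push-down rewrite can be sketched on a toy tuple-based IR. This is hypothetical and only for illustration; Aesara's real rewrites operate on FunctionGraphs. Once the ifelse sits at the top, each branch is an ordinary graph that the existing rewrites can handle.

```python
# Toy expression nodes: ("ifelse", cond, then, else_) or (op_name, *args);
# strings are leaves. Hypothetical mini-IR, not Aesara's API.

def push_ifelse_down(node):
    """Rewrite op(ifelse(c, a, b)) -> ifelse(c, op(a), op(b)) for unary ops, recursively."""
    if not isinstance(node, tuple):
        return node
    head, *args = node
    args = [push_ifelse_down(a) for a in args]
    if (
        head != "ifelse"
        and len(args) == 1
        and isinstance(args[0], tuple)
        and args[0][0] == "ifelse"
    ):
        _, cond, then_, else_ = args[0]
        # Duplicate the unary op into both branches, then keep pushing
        return ("ifelse", cond, push_ifelse_down((head, then_)), push_ifelse_down((head, else_)))
    return (head, *args)


g = ("log", ("ifelse", "c", ("add", "a", "b"), ("sum", ("add", "a", "b"))))
print(push_ifelse_down(g))
```

Ops with several inputs would need the other inputs copied into both branches, which is where a real rewrite has to be more careful.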

If you're really that concerned, use Assert (or a subclass) to represent constraints like X.shape[i] != 1 with equivalent in-graph Ops like at.neq(X.shape[i], 1). Next, add a Feature that picks up these nodes and adds an entry mapping X.shape[i] to (at.neq, 1) in a dict. There's your constraint store. Finally, create a local rewrite that tracks all other boolean Ops and replaces matching conditions in the constraint store with conforming True values. Now, those entries can be constant folded–or explicitly removed by the rewrite itself in some cases.
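The constraint-store idea can be sketched with plain Python data structures. Everything here (`ShapeConstraintStore`, the tuple encoding of conditions) is hypothetical and stands in for a Feature plus a local rewrite; it only shows how a recorded `X.shape[0] != 1` lets conditions like those in `unbroadcast_to` fold to constants.

```python
class ShapeConstraintStore:
    """Hypothetical store of facts like `X.shape[0] != 1` gathered from Assert nodes."""

    def __init__(self):
        # Maps (variable, axis) -> set of (comparison, constant) pairs known to hold
        self._facts = {}

    def add(self, var, axis, op, const):
        """Record that `var.shape[axis] <op> const` holds."""
        self._facts.setdefault((var, axis), set()).add((op, const))

    def simplify(self, cond):
        """Fold `cond` = (op, (var, axis), const) to True/False when it is known."""
        op, (var, axis), const = cond
        facts = self._facts.get((var, axis), set())
        if (op, const) in facts:
            return True
        # The negation of a known fact folds to False
        negated = {"eq": "neq", "neq": "eq"}.get(op)
        if negated is not None and (negated, const) in facts:
            return False
        return cond


store = ShapeConstraintStore()
store.add("X", 0, "neq", 1)  # e.g. recorded from an Assert on neq(X.shape[0], 1)
print(store.simplify(("neq", ("X", 0), 1)))  # True
print(store.simplify(("eq", ("X", 0), 1)))   # False: the summing branch is dead
print(store.simplify(("eq", ("X", 1), 1)))   # unknown, returned unchanged
```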

There are numerous variations of this basic idea, but the point is that the specific IfElse nodes that generalize Elemwise's gradient graphs are not inherently difficult to remove. Also, doing so could be the start of a much more important set of features/capabilities.

  • We won't be able to compile (or at least JIT) gradient graphs to more restricted backends like JAX, unless users specify exact shapes for all the inputs.

What makes you think that? This assertion really needs a concrete example (e.g. which Ops can we currently not transpile due to our non-strict interpretation of TensorType.broadcastable values), because it sounds like it's potentially referring to some very unrelated things. For instance, the JAX limitations we currently have are simply due to JAX restrictions. We could work around some of those, but definitely not all or even most, if we had better shape inference (i.e. if we could infer more static/Type-level shapes).

I don't have any benchmarks yet so it's all speculation at this point.

We could choose between strict Op restrictions that guarantee the availability of some unknown optimization benefits in certain cases, or no Op restrictions and the potential for a (temporary) loss of the aforementioned optimizations. (I'm obviously not considering the "do nothing and retain all the old bugs" option.)

We chose the latter. The former would've forced us to fix every other Op.make_node and Op.infer_shape except Elemwise (including non-Aesara Ops), and the latter forces us to fix Elemwise and its rewrites–as well as any other Ops/rewrites that make the same assumptions as Elemwise, and doesn't appear to be many, given that the vast majority of things still function correctly after these changes.

Again, at the design level, our choice does not prevent us from attaining any of the same performance benefits as the former. We just need to keep moving in the direction we're headed and establish a means of specifying and using more granular static shape constraints, then we can have all the same optimizations as before and more.

ricardoV94 commented 2 years ago

A lot of this is already covered in https://github.com/aesara-devs/aesara/issues/335 and other, related issues/PRs, but I'll recap the basics in the following, because the amount of confusion surrounding this and the apparently strong appeal of naive reversions is alarming.

I hope the "appeal to naive reversions" wasn't directed at me.

Let's start with some very simple and direct proof that Elemwise's strict interpretation is not shared by the rest of Theano

import numpy as np

import theano
import theano.tensor as tt

theano.__version__
# '1.0.5'

x = tt.row("x")
X = tt.matrix("X")

x.broadcastable
# (True, False)

X.broadcastable
# (False, False)

# When a user says a dimension is broadcastable, it must be,
# so the interpretation of `True` is strict.
x.type.filter_variable(np.ones((2, 3)))
# TypeError: ('Non-unit value on shape on a broadcastable dimension.', (2, 3), (True, False))

x.type.is_valid_value(np.ones((2, 3)))
# False

# Theano obviously doesn't interpret `False` in a strict sense,
# because, if it did, these would necessarily fail:
X.type.filter_variable(np.ones((1, 1)))
# TensorConstant{(1, 1) of 1.0}

X.type.is_valid_value(np.ones((1, 1)))
# True

I think what's happening here is that broadcastable is not an isolated property of a Type, but something about how Types behave in Ops that allow for input broadcasting. When you add two matrices of shape (1, m), there is no broadcasting happening, so it's valid to use inputs that are not allowed to broadcast.

In other words, something must have a shape of 1 to be broadcastable but can still have a shape of 1 and not be required to broadcast. The flags were not really about the Types but what you could do with the Types. In practice you could disallow Numpy broadcasting behavior in Theano (or more precisely, Numpy broadcasting behavior was an opt-in deal).

I don't think this is useful for users, but was clearly used by the devs to produce valid gradient graphs.

I don't think we should offer the option to disable Numpy broadcasting at run-time, but we should offer the users the tools to tell Aesara when broadcasting won't be needed so that it can produce simpler graphs.

We can even raise if users said something wouldn't be broadcastable but has a shape of 1. That will make Aesara much more intuitive.
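A minimal sketch of this proposal, assuming a hypothetical `GT1` marker (Aesara's static shapes currently accept only an exact length or `None`): the Type rejects length-1 values on axes declared non-broadcastable, and in exchange the gradient can decide per axis, at graph-construction time, whether it needs a sum, nothing, or a run-time ifelse.

```python
# Hypothetical static-shape entries: an int (exact length), None (unknown),
# or GT1, a made-up marker for "length > 1, so this axis never broadcasts".
GT1 = "gt1"


def validate_static_shape(value_shape, static_shape):
    """Raise if a run-time shape violates the declared static shape."""
    if len(value_shape) != len(static_shape):
        raise ValueError("wrong number of dimensions")
    for axis, (n, spec) in enumerate(zip(value_shape, static_shape)):
        if spec is None:
            continue
        if spec == GT1:
            if n == 1:
                raise ValueError(f"axis {axis} was declared non-broadcastable but has length 1")
        elif n != spec:
            raise ValueError(f"axis {axis} has length {n}, expected {spec}")


def grad_reduction_plan(static_shape):
    """Per input axis, how the gradient of a broadcasting Elemwise must treat it.

    'sum'    - static length 1: summing is always correct (a no-op if nothing broadcast)
    'keep'   - GT1 or an exact length other than 1: the axis can never be broadcast
    'ifelse' - unknown length: the choice has to be made at run time
    """
    plan = []
    for spec in static_shape:
        if spec is None:
            plan.append("ifelse")
        elif spec == 1:
            plan.append("sum")
        else:
            plan.append("keep")
    return plan


print(grad_reduction_plan((1, None)))    # ['sum', 'ifelse']
print(grad_reduction_plan((GT1, None)))  # ['keep', 'ifelse']
validate_static_shape((5, 3), (GT1, None))  # passes silently
```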

Here are some illustrations of Elemwise's strict and inconsistent interpretation of TensorType.broadcastable values:

Y = tt.matrix("Y")

XY_fn = theano.function([X, Y], X * Y)

theano.printing.debugprint(XY_fn, print_type=True)
# Elemwise{mul,no_inplace} [id A] <TensorType(float64, matrix)> ''   0
#  |X [id B] <TensorType(float64, matrix)>
#  |Y [id C] <TensorType(float64, matrix)>

XY_fn(np.ones((2, 2)), np.ones((1, 2)))
# ValueError: Input dimension mis-match. (input[0].shape[0] = 2, input[1].shape[0] = 1)

Everything checks out at the Type-level with those arguments, but the call fails in Op.perform (i.e. run-time), and easily avoidable run-time errors aren't good for anything.

Let's try some other arguments that conflict with the False-means-not-one interpretation:

XY_fn(np.ones((1, 2)), np.ones((1, 2)))
# array([[1., 1.]])

Notice how Elemwise is selectively interpreting a False Type.broadcastable value as an indication that the corresponding shape value is not equal to one, and the check/assumption is only ever effectively applied to dimensions that are actively being broadcasted.

Again, this makes sense if you read it not as saying something about the shape of the inputs, but as saying whether broadcasting is actually needed to combine them. In this case it wasn't, so it's valid to use non-broadcastable (according to Theano rules) inputs.

Unfortunately, the changes introduced by the aforementioned PRs were not the end of this Elemwise refactoring. There are still some things to clean up. (Fortunately, doing so is also helping us identify outstanding gaps in our tests.)

We chose the latter. The former would've forced us to fix every other Op.make_node and Op.infer_shape except Elemwise (including non-Aesara Ops), and the latter forces us to fix Elemwise and its rewrites–as well as any other Ops/rewrites that make the same assumptions as Elemwise, and doesn't appear to be many, given that the vast majority of things still function correctly after these changes.

Again, at the design level, our choice does not prevent us from attaining any of the same performance benefits as the former. We just need to keep moving in the direction we're headed and establish a means of specifying and using more granular static shape constraints, then we can have all the same optimizations as before and more.

I think we might need to allow for more granular static types (in particular, shape>1), which will nevertheless involve quite some work in make_node and rewrites.

I still want to see how the ifelse fares. I would be very happy if this all turned out to be a non-issue.

ricardoV94 commented 2 years ago

I'll try to open a PR that adds tests for the bugs in the gradients of Elemwise, Gemm, and BroadcastTo (those are all the Ops I know of that allow for run-time broadcasting).

brandonwillard commented 2 years ago

I hope the "appeal to naive reversions" wasn't directed at me.

It was directed at all of us.

brandonwillard commented 2 years ago

I don't think this is useful for users, but was clearly used by the devs to produce valid gradient graphs.

Although we're basically guessing why things were done in Theano, I have a strong feeling that, given the status of IfElse near the end of Theano's development, the reason a kludge like Elemwise's use of TensorType.broadcastable was employed probably has more to do with Theano's lack of support for conditionals.

Put another way, they might have wanted to go down a path like the one we're going down now, but IfElse wasn't available, finished, or guaranteed to be evaluated efficiently (i.e. lazily).

aseyboldt commented 2 years ago

I gave the ElemwiseSum approach a bit more thought, and the more I think about it the more I like it :-)

Just to make sure we are all on the same page, here is a summary of how that would go:

So the remaining question would be how to implement the Elemwise with an included reduction. It turns out we can reuse the numpy broadcasting rules to do exactly that. The following implementation will already do what we want:

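@numba.njit
def elemwise_sum(a, b, out_shape):
    # Accumulate into a zero-initialized result (`np.empty` would leave garbage);
    # float64 stands in for the Elemwise's output dtype.
    result = np.zeros(out_shape, dtype=np.float64)
    # Broadcasting gives the `out` view stride 0 along every reduced axis
    a, b, out = np.broadcast_arrays(a, b, result)
    for x, y, z in np.nditer((a, b, out)):
        z[()] += elemwise_func(x, y)
    # Return the unbroadcasted array, not the broadcast view
    return result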

Why does this work? If out is broadcasted along a dimension, the broadcasted version of out will have stride set to zero in that axis, so in the iteration we add to the same element of out multiple times.
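The stride-0 trick can be checked in plain NumPy. This sketch only inspects strides and uses NumPy's own `np.nditer` with the `reduce_ok` flag; it says nothing about the code Numba generates.

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)
out = np.zeros((1, 3))  # keep axis 1, reduce axis 0

# Broadcasting `out` against `a` yields a view whose stride along axis 0 is 0,
# so both row indices map to the same memory.
_, out_b = np.broadcast_arrays(a, out)
print(out_b.strides)  # (0, 8) for float64

# nditer can accumulate through such a broadcast operand when reductions are allowed.
with np.nditer((a, out), flags=["reduce_ok"], op_flags=[["readonly"], ["readwrite"]]) as it:
    for x, z in it:
        z[...] += x
print(out)  # [[3. 5. 7.]]
```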

However, the compiler doesn't know at compile time what the strides of the arrays will be, so it produces code that works for all strides and isn't particularly fast for any of them. In most applications we only really care about the innermost loop, however, so we can specialize that loop to make sure it can be vectorized:

import numba
import numpy as np

# Get information about simd vectorization from llvm
#import llvmlite.binding as llvm
#llvm.set_option('', '--debug-only=loop-vectorize,iv-descriptors')

@numba.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba.core import types, cgutils
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

@numba.extending.intrinsic
def make_writable(typingctx, src):
    """Make a readonly array writable."""
    from numba.core import types, cgutils
    sig = src.copy(readonly=False)(src)

    def codegen(cgctx, builder, sig, args):
        (arg,) = args
        return numba.core.imputils.impl_ret_borrowed(cgctx, builder, sig.return_type, arg)
    return sig, codegen

# fastmath=True (is "afn" enough?) will use faster but less
# precise special functions if SVML is installed
# (conda install -c numba icc_rt)
#@numba.njit(fastmath={"afn"})
@numba.njit(fastmath=True)
def elemwise_func(a, b):
    return np.exp(a + b)

# Slow reference implementation
@numba.njit
def elemwise_sum_slow(a, b, out):
    out[...] = 0
    a, b, out = np.broadcast_arrays(a, b, out)

    for x, y, z in np.nditer((a, b, out)):
        z[()] += elemwise_func(x, y)

# Fast version that tries to use SIMD in some cases
@numba.njit(fastmath={"reassoc", "contract"})
def elemwise_sum(a, b, out):
    # TODO There are cases where we don't need this.
    # Can we somehow put it into the loop?
    out[...] = 0

    a, b, out = np.broadcast_arrays(a, b, out)

    # Numpy returns readonly arrays by default
    # so that users don't get surprised by the
    # arrays with stride 0. It is safe to write,
    # however, and in this case we *want* the behaviour
    # of strides == 0.
    a = make_writable(a)
    b = make_writable(b)
    out = make_writable(out)

    # Last loop is reduction and a and b are contiguous
    if (
        out.strides[-1] == 0
        and a.strides[-1] == 8  # TODO use dtype.itemsize
        and b.strides[-1] == 8
    ):
        for idx in np.ndindex(out.shape[:-1]):
            tmp = 0
            x = a[idx]
            y = b[idx]
            k = 0
            N = a.shape[-1]
            while k < N:
                tmp += elemwise_func(x[k], y[k])
                k += 1
            if N > 0:
                out[idx][0] += tmp

    # Last loop is reduction, but a or b are not contiguous
    elif out.strides[-1] == 0:
        for idx in np.ndindex(out.shape[:-1]):
            tmp = 0.
            x = a[idx]
            y = b[idx]
            k = 0
            N = a.shape[-1]
            while k < N:
                tmp += elemwise_func(x[k], y[k])
                k += 1
            if N > 0:
                out[idx][0] += tmp
    # All last loops are contiguous
    elif (
        a.strides[-1] == 8
        and b.strides[-1] == 8
        and out.strides[-1] == 8
    ):
        for idx in np.ndindex(out.shape[:-1]):
            x = a[idx]
            y = b[idx]
            z = out[idx]
            N = a.shape[-1]
            # Tell numba that the array is c contiguous
            z = numba.carray(address_as_void_pointer(z.ctypes.data), N, dtype=out.dtype)
            x = numba.carray(address_as_void_pointer(x.ctypes.data), N, dtype=x.dtype)
            y = numba.carray(address_as_void_pointer(y.ctypes.data), N, dtype=y.dtype)
            k = 0

            while k < N:
                # Accumulate: `out` may still have stride 0 along an outer axis
                z[k] += elemwise_func(x[k], y[k])
                k += 1
    else:
        # This last one doesn't seem to get vectorized at all,
        # but I'm not sure why exactly...
        for idx in np.ndindex(out.shape[:-1]):
            x = a[idx]
            y = b[idx]
            z = out[idx]
            N = a.shape[-1]
            assert z.strides[-1] > 0
            assert x.strides[-1] >= 0
            assert y.strides[-1] >= 0
            k = 0

            while k < N:
                # Accumulate: `out` may still have stride 0 along an outer axis
                z[k] += elemwise_func(x[k], y[k])
                k += 1

# Pure python impl using np.nditer
def elemwise_sum_python(a, b, out):
    out[...] = 0
    it = np.nditer(
        (a, b, out),
        flags=["external_loop", "reduce_ok", "buffered"],
        op_flags=[["readonly"], ["readonly"], ["readwrite"]]
    )
    with it:
        for x, y, z in it:
            #z[...] += np.exp(x + y)
            z[...] += elemwise_func(x, y)

This still isn't perfect, but it usually beats other implementations, often by a factor of 5. It could still be improved for cases where the inner loop is very small but the next inner loop is contiguous (or a reduction) as well, or where the input arrays are in Fortran order.

Some timings on my machine with icc_rt from numba installed:


I haven't given the question of how that would work in jax any thought yet...

brandonwillard commented 2 years ago

@aseyboldt, in the development of Aesara we need to assign greater value to the use/reuse/extension of existing and more general (sub)systems and features. The custom Op approach you're advocating, and its very specialized Numba transpilations, can serve as a sort of target for the end results of those systems and features, but we can't develop past/around them; otherwise, they'll end up becoming less and less useful over time.

As I stated above, we need the same level of shape inference and ability to assign constraints that would allow us to produce efficient graphs using existing Ops in many other situations.

In other words, how can we produce similarly optimized Numba code for the graphs we could generate via rewriting that use existing Ops?

aseyboldt commented 2 years ago

Sorry, but what is so specialized about an op that does elemwise reduction? It is a basic thing we want all over the place, and it can't really be built from smaller parts either.

brandonwillard commented 2 years ago

It is a basic thing we want all over the place, and it can't really be built from smaller parts either.

We need to establish that definitively first.

aseyboldt commented 2 years ago

Logp graphs are one big reduction, because we want a single logp value. And storing all intermediate values is known to be slow (memory pressure and all that). Just have a quick look at the timings above if you don't believe me or do some experiments yourself.

If you have ideas about how to split that apart into separate Ops, I'm really curious how you'd even approach it so that the compiler can still use SIMD.
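To illustrate the memory argument (not as a substitute for a compiled kernel), here is a blocked NumPy version of a normal logp sum. `fused_logp_sum` and its `block` parameter are hypothetical; the sketch assumes a 1-D `x` and scalar `mu`, `sigma`. Only one block of intermediates is alive at a time, whereas a fused kernel would keep just the running total.

```python
import numpy as np


def fused_logp_sum(x, mu, sigma, block=4096):
    """Sum of normal log-densities without materializing the full elementwise array."""
    const = -np.log(sigma) - 0.5 * np.log(2 * np.pi)
    total = 0.0
    for start in range(0, x.size, block):
        z = (x[start:start + block] - mu) / sigma  # at most `block` temporaries
        total += np.sum(-0.5 * z * z + const)
    return total


rng = np.random.default_rng(0)
x = rng.normal(size=10_000)
unfused = np.sum(-0.5 * ((x - 1.0) / 2.0) ** 2 - np.log(2.0) - 0.5 * np.log(2 * np.pi))
print(np.isclose(fused_logp_sum(x, 1.0, 2.0), unfused))  # True
```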

brandonwillard commented 2 years ago

Now come on, logp graphs are one big reduction, because we want a single logp value. And storing all intermediate values is known to be slow (memory pressure and all that). Just have a quick look at the timings above if you don't believe me or do some experiments yourself.

I think we're talking about different things, because I'm referring to the graphs produced by Elemwise.grad, which do not determine the memory usage patterns. Those graphs undergo a few more rewrite passes that ultimately determine that.

My concerns are that we, for example, prematurely use a custom Op that represents something like at.sum(x, axis=a), and this prevents us from reasoning about/with the resulting node as if it were the at.sum that it actually is.