aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Wrong gradients when inputs are dynamically broadcasted #1089

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

This bug is an unexpected consequence of https://github.com/aesara-devs/aesara/pull/928 and rewrites that make certain assumptions: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1291561804

import aesara
import aesara.tensor as at
import numpy as np

x_row = at.row("x_row")
x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_row = aesara.function([x_row, y], x_row_grad)
print(f_row(np.ones((1, 5)), np.ones((5, 5))))
# [[5. 5. 5. 5. 5.]]

f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
print(f_matrix(np.ones((1, 5)), np.ones((5, 5))))
# [[1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]]

The faulty logic is found here: https://github.com/aesara-devs/aesara/blob/7f4e0ab443826af7f3f48d77bf13027ab6bdff69/aesara/tensor/elemwise.py#L552-L570

This is also likely a problem in the grad of BroadcastTo, which calls infer_broadcastable, which in turn defaults to assuming an input will not have been broadcasted if a static shape of 1 can't be inferred.

https://github.com/aesara-devs/aesara/blob/7f8af9bc28755d93dca3afff2534a8a5f5ecbd80/aesara/tensor/extra_ops.py#L1593

https://github.com/aesara-devs/aesara/blob/7f8af9bc28755d93dca3afff2534a8a5f5ecbd80/aesara/tensor/basic.py#L1313

This also affects GEMM since #986.

I am not sure there's a good solution to this problem, as we would need an expression whose output shape differs depending on whether the runtime inputs are broadcasted or not.

Solution might look something like: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1202225638
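For reference, the value the x_matrix case should return for the inputs above can be checked directly in NumPy: the upstream gradient of the sum is a matrix of ones, and the axis that was actually broadcast at run time has to be summed out (keeping its dimension):

import numpy as np

# Upstream gradient of `at.sum(x + y)` w.r.t. the broadcasted `x` is all ones;
# summing it over the run-time-broadcast axis gives the expected result.
print(np.ones((5, 5)).sum(axis=0, keepdims=True))
# [[5. 5. 5. 5. 5.]]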

brandonwillard commented 2 years ago

Also, I still think the functionality of your proposed custom ElemwiseSum and/or extension to Elemwise overlaps with CAReduce, so we need to consider all this in that context first.

ricardoV94 commented 2 years ago

The ElemwiseSum sounds like a useful specialization Op in itself, especially in conjunction with the Composite Op, for graphs that end up being reduced in the output.

The fact that it might also solve the grad issue is neat.

The logic shown for Numba and Python mode seems equally easy to port to the existing C code.

It sounds more manageable (and performant) than the IfElse approach at this point.

For JAX, I don't think there is any easy approach that will work, given JIT limitations. They can definitely generate grad graphs with broadcasted inputs, but I don't know how they do it, and especially whether it's compatible with the way we build graphs.

brandonwillard commented 2 years ago

It sounds more manageable (and performant) than the IfElse approach at this point.

The approaches are not mutually exclusive; that's the big issue with this discussion, and solving this particular issue with only a custom Op that unnecessarily hides its simple conditional-sum form prevents use of existing and future rewrites and graph comparisons. This was clearly undesirable in Theano (see the old contributor/dev. documentation) and it still is.

Also, this discussion is conflating performance-related considerations with the representation of computations in our graphs. Custom Op-using nodes can always replace sub-graphs of basic Ops during the specialization passes when/if that's necessary or advantageous, so, if one is going to argue for more general use of a custom Op, they really need to show why/how such replacements aren't reasonable.

Regardless, the benefits of improving our shape inference and adding/handling constraints are many, and the results are more accurate compile-time information, which is what allows us to produce more performant transpiled results. If you don't want to help develop toward that right now and would prefer a narrower, shorter-term solution, that's your choice, but let's not convince ourselves and others that such solutions are anything more than they actually are (i.e. different ways to do the same things, just with different priorities).

ricardoV94 commented 2 years ago

Just to recap, these are the options that have been proposed:

  1. IfElse graph, possibly wrapped in an OpFromGraph for easier manipulation in graph rewrites and more readable graph printing
  2. Extend CAReduce to have a symbolic axis. This is fine as long as keepdims=True; otherwise we wouldn't know the number of dimensions at compile time, which Aesara requires. However, I think keepdims is currently handled outside of CAReduce, by DimShuffles? (See the sketch after this list.)
  3. Custom SumToShape Op
  4. Extend (or specialize) Elemwise so that it integrates the behavior of SumToShape in its perform method (Elemwise + Identity = SumToShape)
  5. Extend CAReduce so that it integrates the behavior of SumToShape in its perform method
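A quick way to check the keepdims point in option 2 (a sketch; the exact dprint labels vary between versions):

import aesara
import aesara.tensor as at

x = at.matrix("x")
y = at.sum(x, axis=0, keepdims=True)

# The reduction itself is a `Sum` (a `CAReduce` subclass) without keepdims;
# the kept dimension is reintroduced by a surrounding `DimShuffle`.
aesara.dprint(y)
# InplaceDimShuffle{x,0} [id A]
#  |Sum{axis=[0], acc_dtype=float64} [id B]
#    |x [id C]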

It seems that @brandonwillard is favoring 1 (or possibly 5?) at the moment, while @aseyboldt and I are favoring 3/4.

How shall we proceed?

brandonwillard commented 2 years ago

Just to recap, these are the options that have been proposed:

  1. IfElse graph, possibly wrapped in an OpFromGraph for easier manipulation in graph rewrites and more readable graph printing
  2. Extend CAReduce to have a symbolic axis. This is fine as long as keepdims=True; otherwise we wouldn't know the number of dimensions at compile time, which Aesara requires. However, I think keepdims is currently handled outside of CAReduce, by DimShuffles?
  3. Custom SumToShape Op
  4. Extend (or specialize) Elemwise so that it integrates the behavior of SumToShape in its perform method (Elemwise + Identity = SumToShape)
  5. Extend CAReduce so that it integrates the behavior of SumToShape in its perform method

It seems that @brandonwillard is favoring 1 (or possibly 5?) at the moment, while @aseyboldt and I are favoring 3/4.

How shall we proceed?

Option 1 should be the most straightforward way to fix the problem right now, but it leads to graph optimization concerns. Those concerns can be addressed by reducing/removing the additional IfElse sub-graphs using shape inference and a means of adding and tracking "shape > 1" constraints.

Options 3 and 4 are just other ways to represent the IfElse sub-graphs. They give rise to the same optimization concerns as the IfElse sub-graphs themselves, and they add to those concerns by either adding a new Op with no rewrite coverage or changing an existing Op so that all the code touching it (e.g. rewrites) needs to be updated/reconsidered.

Option 5 is integral to 3 and 4 and must be considered first, because all of @aseyboldt's Numba and performance comments and examples could easily apply to CAReduce and only that. Basically, all the custom Op performance talk is probably just an independent discussion about improvements to CAReduce's Numba implementation.

If for some reason it's not, then we can consider adding a custom Op that replaces the IfElse sub-graphs introduced by Option 1 during specialization.

aseyboldt commented 2 years ago

Maybe we should pick apart a few of the things I posted a bit; I might not have distinguished the different ideas enough :-)

Maybe I just don't get option 5 yet however. How would that work?

brandonwillard commented 2 years ago

Maybe we should pick apart a few of the things I posted a bit; I might not have distinguished the different ideas enough :-)

  • First, there was the idea to have a custom op to replace the IfElse constructs and figure out the details of the broadcast. This could be used to build a solution, but it has nothing to do with the ElemwiseSum approach, which wouldn't need this custom op.

If ElemwiseSum can be used in place of the IfElse graphs to solve this issue, then it is a custom Op that replaces the IfElse graphs. If the use of ElemwiseSum also requires the IfElses, then it's not really a solution to this issue.

  • For the ElemwiseSum approach the basic observation was that we can use the broadcasting logic in numpy/numba directly to write a function that does the correct reduction, with some pretty short and simple code (the elemwise_sum_slow implementation above). So, just like we don't have to figure out in aesara what broadcasting is doing exactly to implement the elemwise because we just use the logic in numba, we can do the same for the gradient. I'm pretty happy with that basic idea, but I was worried about performance here, because the elemwise_sum_slow version is, well, slow.

Writing an Op or Numba function that can perform these operations is not the problem we need to solve in this issue. The Ops needed to do this already exist, and, by using them, we get downstream shape inference, gradients, rewrites, etc. It doesn't get much simpler than that.

Your ElemwiseSum would need to implement all the same graph-level logic as the combination of those existing Ops and rewrites, so it's definitely not simpler in any of the ways that matter.

  • So this is where the specialization comes in that should solve the speed issues for the vast majority of cases. It also doesn't explicitly figure out the broadcasting patterns; it just specializes cases like "the inner loop is a reduction" and "the input arrays are contiguous". It is true that we could use those approaches to also speed up CAReduce, but I don't think CAReduce and ElemwiseSum are the same thing. ElemwiseSum could represent np.exp(x).sum(), but CAReduce can't, or am I missing something?

Maybe I just don't get option 5 yet however. How would that work?

You are talking about optimizations relevant to CAReduce, and, yes, it is used to implement graphs that compute np.exp(x).sum():

import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.exp(x).sum()

aesara.dprint(y)
# Sum{acc_dtype=float64} [id A]
#  |Elemwise{exp,no_inplace} [id B]
#    |x [id C]

type(y.owner.op).mro()
# [aesara.tensor.math.Sum,
#  aesara.tensor.elemwise.CAReduceDtype,
#  aesara.tensor.elemwise.CAReduce,
#  aesara.link.c.op.COp,
#  aesara.graph.op.Op,
#  aesara.graph.utils.MetaObject,
#  aesara.link.c.interface.CLinkerOp,
#  aesara.link.c.interface.CLinkerObject,
#  object]

Your ElemwiseSum is a very specific case of CAReduce. Given a CAReduce node, one should be able to generate the same Numba code as your examples. In other words, all the necessary compile-time information is explicitly represented in the CAReduce graphs we already produce.

Anyway, if you can use this issue—and the specialized Numba functions you've written for it—to improve CAReduce, that's the kind of help we need.

ricardoV94 commented 2 years ago

@aseyboldt again I have the feeling the ElemwiseSum may make sense as a specialization Op that you introduce in graphs that have sum + elemwise (and you don't need the output of elemwise), but that can be pursued almost independently of how we solve the grads.

Back to the grads. I think it's valid to ask why we hide broadcasting logic inside the Elemwise Op's perform method, but would not be willing to do the same with "unbroadcasting" (either in Elemwise, CAReduce, or a new Op)?

After all, we could have decided to define the whole dynamic broadcasting of Elemwises with IfElses (or explicit calls to BroadcastTo) at the graph level, whenever we don't know if the inputs of an Elemwise will have to be broadcasted.

I think it's fair to say that we did this because most graphs would otherwise be unnecessarily complex and probably less performant. Also, the broadcasting logic is trivial to implement in the perform method (just set the strides of degenerate input dims to 0). This has also pushed many specific issues onto rewrites that we are still tweaking, but I don't sense any desire to move such logic to the graph level just to avoid worrying about rewrites.

I think the very same argument works for the unbroadcasting behavior, which will be almost as fundamental an operation in an autodiff library that allows for dynamic/runtime broadcasting.

Also, as @aseyboldt showed, the logic is similarly trivial to achieve (just set the strides of the degenerate output dimensions to zero and use in-place addition instead of assignment).
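As a rough NumPy sketch of both tricks (zero strides for the forward broadcast; a reduction over the broadcast axes for the backward "unbroadcast"; as_strided is used purely for illustration):

import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(5.0).reshape(1, 5)

# Forward broadcasting: a (5, 5) view of x obtained by giving the degenerate
# axis a zero stride -- no copy is made.
x_bcast = as_strided(x, shape=(5, 5), strides=(0, x.strides[1]))
assert np.array_equal(x_bcast, np.broadcast_to(x, (5, 5)))

# Backward "unbroadcasting" (sum-to-shape): the gradient w.r.t. x sums the
# upstream gradient over every axis that was broadcast at run time.  A C/Numba
# implementation can get the same effect by writing through a zero-stride
# output view with in-place addition; in plain NumPy it is just a reduction.
grad_out = np.ones((5, 5))
grad_x = grad_out.sum(axis=0, keepdims=True)
assert grad_x.shape == x.shape  # (1, 5), every entry equal to 5.0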

Edit: removed one paragraph.

brandonwillard commented 2 years ago

@aseyboldt again I have the feeling the ElemwiseSum may make sense as a specialization Op that you introduce in graphs that have sum + elemwise (and you don't need the output of elemwise), but that can be pursued almost independently of how we solve the grads.

In order to do exactly what ElemwiseSum seems to imply, we'll probably need another type of "fusion" rewrite similar to FusionOptimizer, because CAReduce will need Composite scalar_ops like sum(exp(x), exp(y)).

Back to the grads. I think it's valid to ask why we hide broadcasting logic inside the Elemwise Op's perform method, but would not be willing to do the same with "unbroadcasting" (either in Elemwise, CAReduce, or a new Op)?

We aren't "hiding" broadcasting logic. Elemwise has always performed some of the broadcasting itself, and we've only started to allow that to work without the strong compile-time broadcasting assumptions that led to numerous Theano bugs and limitations. See my earlier explanation of those.

Don't forget, Elemwise represents a subset of NumPy operations that inherently perform broadcasting, and, in matching NumPy's interface and design, we naturally assign Elemwise the responsibility of broadcasting. This is the historical design decision from which we work.

After all, we could have decided to define the whole dynamic broadcasting of Elemwises with IfElses (or explicit calls to BroadcastTo) at the graph level, whenever we don't know if the inputs of an Elemwise will have to be broadcasted. I think it's fair to say that we did this because most graphs would otherwise be unnecessarily complex and probably less performant. Also, the broadcasting logic is trivial to implement in the perform method (just set the strides of degenerate input dims to 0).

If you're asking me, it was really due to the reasons I mentioned above; otherwise, if there end up being reasons to move Elemwise's broadcasting to the graph-level, we should consider doing it.

Complexity wouldn't necessarily be an issue, and, if it were, the question is whether or not we want to pay the costs of dealing with it. If the costs can be shared with other things we want to accomplish (e.g. general rewriting improvements, better graph printing capabilities, etc.), then we might be able to justify them.

The same goes for performance concerns; if we can share the costs of making something more performant with our other goals, then we get more value for our efforts. For instance, if we're already investing in our ability to succinctly represent and detect sub-structures within our graphs, then we're also making it easier for us to track the sub-graphs that represent these computations and replace them with whatever low-level implementations we want.

This also brought many specific issues onto rewrites that we are still tweaking, but I don't sense any desire to move such logic to the graph level to avoid worrying about rewrites.

I've already addressed the fact that we are in the midst of ongoing work and why, and, as far as I can tell, moving the broadcasting logic in Elemwise to the graph wouldn't have helped with any of the issues we've discussed.

I think the very same argument works for the unbroadcasting behavior, which will be almost as fundamental an operation in an autodiff library that allows for dynamic/runtime broadcasting.

Which argument? I don't see how hypothetically moving the broadcasting in Elemwise to the graph works as an argument for changing the basic definition of an existing Op. Same for the Elemwise changes that we have made; those didn't make Elemwise do anything outside of its basic scope of definition (i.e. as an Op that performs a scalar computation element-wise); however, making Elemwise perform a sum and/or collapse dimensions absolutely does change its definition. Worse yet, such a change would introduce a redundancy between Elemwise and CAReduce. Why would we want that?

We already know that this issue can be fixed by using IfElse, and that the resulting graphs can be completely removed at compile-time when we implement constraints. Any alternative approach will need to be able to do effectively the same thing (i.e. remove similar run-time logic that's obviated by compile-time information), and I don't see how changes to Elemwise would help with any of that.

Really, I need some concrete details about these Elemwise "unbroadcasting" changes in order to say anything more specific.

Also, as @aseyboldt showed, the logic is similarly trivial to achieve (just set the strides of the degenerate output dimensions to zero and use in-place addition instead of assignment).

The stride-based logic looks great for our CAReduce implementations, but it doesn't help us determine how we should represent the missing logic of this issue at the graph-level, nor do I see how it justifies adding "unbroadcasting" to Elemwise.

In order to help get this issue back on track, let's move the Numba-based performance-related material to a new issue/Discussion about enhancements to CAReduce. That work could end up being more important than this issue, and we don't need the extra noise in either.

Otherwise, if I'm still missing something important regarding the Elemwise changes to which you've been alluding, then you might need to elaborate on the relevant details.

aseyboldt commented 2 years ago

We can certainly leave the performance considerations of ElemwiseSum to the side for now; I just thought it was important to make sure this can be implemented in a performant way.

Is your point about CAReduce that you would want to do this in full generality right away and make this an ElemwiseCAReduce, instead of starting with the special case of a sum?

Maybe I'm still not entirely clear on how this would work with just ifelse. The best I could come up with for now would look something like this (mostly pseudocode).

import aesara.ifelse
import aesara.tensor as at


def is_broadcasted(input_shape, output_shape, axis):
    n_out = len(output_shape)
    n_in = len(input_shape)
    assert axis >= 0
    assert axis < n_out
    assert n_out >= n_in

    if axis < n_out - n_in:
        return True

    shape_out = output_shape[axis]
    shape_in = input_shape[axis - (n_out - n_in)]

    return at.neq(shape_out, shape_in)

class Elemwise:
    def grad(...):
        d_inputs = []

        for input in inputs:
            # Compute the derivative with respect to
            # the input without final accumulation
            d_input = ...

            for axis in range(len(output_shape)):
                d_input = aesara.ifelse.ifelse(
                    is_broadcasted(input.shape, output.shape, axis),
                    d_input.sum(axis=axis, keepdims=True),
                    d_input,
                )
            # Or maybe this? I think this would require deep changes to CAReduce, because here
            # the axis is a tensor now, and currently it is a static Op attribute.
            # broadcasted = [is_broadcasted(input.shape, output.shape, axis) for axis in range(len(output_shape))]
            # d_input = at.sum(d_input, axis=at.as_tensor_variable(broadcasted).nonzero(), keepdims=True)

            d_input = d_input.reshape(input.shape)
            d_inputs.append(d_input)

        return d_inputs

Is something like this roughly what you had in mind, or am I missing a solution here?

aseyboldt commented 2 years ago

I think we should also consider a sixth option:

Move back to something pretty close to the original theano semantics: We could allow broadcasting only if the broadcasted dimension has a fixed and known shape of 1. We would then have to add a runtime check to make sure we do not broadcast in Elemwise and similar ops if the shape happens to be 1 but was not statically 1.

This would correspond to something like this (again, more pseudocode than anything else):

class TensorType:
    @property
    def broadcastable(self):
        # Broadcastable only when the static length is known to be exactly 1
        return [isinstance(length, int) and length == 1 for length in self.shape]

# TODO I'm sure we can improve the implementation. Must be called in all broadcasting op impls.
@numba.njit
def check_unexpected_broadcast(shapes: list[tuple[int]], is_broadcastable: list[tuple[int]]):
    ndim = max(len(shape) for shape in shapes)
    shapes = [(1,) * (ndim - len(shape)) + shape for shape in shapes]
    is_broadcastable = [(True,) * (ndim - len(bc)) + bc for bc in is_broadcastable]

    broadcasted = [
        max(shape[i] for shape in shapes) != min(shape[i] for shape in shapes)
        for i in range(ndim)
    ]

    # An input is broadcast "unexpectedly" along axis i if its length is 1,
    # the lengths differ across inputs, and its broadcastable flag is False.
    if any(
        broadcasted[i] and shape[i] == 1 and not bcast[i]
        for shape, bcast in zip(shapes, is_broadcastable)
        for i in range(ndim)
    ):
        raise ValueError()

class Elemwise:
    def grad(...):
        d_inputs = []

        for input in inputs:
            # Pseudocode: derivative w.r.t. the (broadcasted) input, with the
            # same number of dimensions as the output
            d_input = ...

            offset = output.ndim - input.ndim
            # Sum (keeping dims) over axes where the input is broadcastable
            # but the output is not
            reduce_axes = [
                j for j in range(offset, output.ndim)
                if input.broadcastable[j - offset] and not output.broadcastable[j]
            ]
            if reduce_axes:
                d_input = d_input.sum(axis=reduce_axes, keepdims=True)
            # Sum away the leading dimensions that the input never had
            if offset > 0:
                d_input = d_input.sum(axis=list(range(offset)))
            d_inputs.append(d_input)

        return d_inputs
brandonwillard commented 2 years ago

Move back to something pretty close to the original theano semantics: We could allow broadcasting only if the broadcasted dimension has a fixed and known shape of 1. We would then have to add a runtime check to make sure we do not broadcast in Elemwise and similar ops if the shape happens to be 1 but was not statically 1.

It sounds like that would severely restrict the kinds of computations/graphs we would be able to support, and produce mirror images of the previous issues (see https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1207111968). Remember, we actually need to support ambiguous static shape/broadcasting information; otherwise, we put unreasonable pressure on every Op, rewrite, etc., to propagate that information perfectly.

Is your point about CAReduce that you would want to do this fully general right away, and make this a ElemwiseCAReduce, instead of the special case of a sum at first?

We've been talking about two independent types of potential CAReduce-specific improvements that, if implemented together in a single Op, would be significantly less useful. I've created issues for each (i.e. https://github.com/aesara-devs/aesara/issues/1116 and https://github.com/aesara-devs/aesara/issues/1115), so we can continue the discussion(s) in those.

Maybe I'm still not entirely clear on how this would work with just ifelse. The best I could come up with for now would look something like this (mostly pseudocode).

See https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1205615007. That IfElse approach can be used in Elemwise.grad right now to close this issue. All other relevant questions are about the performance price we pay for computing that extra conditional logic at run-time. Previously, that logic was performed at "compile-time" within Elemwise.grad, but it relied on inconsistent static shape assumptions. Now, it needs to be performed at run-time—until we can once again remove it at compile-time using rewrites, a direct means of encoding the requisite information within TensorType.shape, etc.

ricardoV94 commented 2 years ago

This helper will probably also have to be updated when/if the Ops that use it in the grad allow for runtime broadcasting: https://github.com/aesara-devs/aesara/blob/3500fec88afdde7661e9a7e5c0e636e08592bebb/aesara/tensor/subtensor.py#L1893-L1924

ricardoV94 commented 1 year ago

Weighing in after some time has passed. I am quite convinced that allowing for dynamic broadcasting (i.e., not known until runtime), or at least allowing it by default, is not worth the trouble: it comes at the cost of having to restrict rewrites and generate less efficient code for what are the most common use cases of Aesara (tensors with unknown shapes that are not broadcastable at runtime).

This understanding was clearly missing (at least for me) when we made the switch from type broadcastable to shape (they are not equivalent) and from disallowing runtime broadcasting to allowing it.

Broadcastable and non-broadcastable graphs are fundamentally different, and it makes sense to represent them differently symbolically. If we really want to allow the general case, I think it's better to force users to specify it manually; then it's fine to use ifelse or whatever for the gradients.

brandonwillard commented 1 year ago

Weighing in after some time has passed. I am quite convinced that allowing for dynamic broadcasting (i.e., not known until runtime), or at least allowing it by default, is not worth the trouble: it comes at the cost of having to restrict rewrites and generate less efficient code for what are the most common use cases of Aesara (tensors with unknown shapes that are not broadcastable at runtime).

This understanding was clearly missing (at least for me) when we made the switch from type broadcastable to shape (they are not equivalent) and from disallowing runtime broadcasting to allowing it.

Broadcastable and non-broadcastable graphs are fundamentally different, and it makes sense to represent them differently symbolically. If we really want to allow the general case, I think it's better to force users to specify it manually; then it's fine to use ifelse or whatever for the gradients.

Open a discussion summarizing your opinion, and make sure it actually covers all the relevant material/points (e.g. all the old bugs and restrictions we'll have to accept again, show that there actually is a real and unacceptable performance trade-off currently, etc.).

Aside from that, we don't need any more random hot-takes on this subject, especially since they're off topic and they've been stealing the focus from this actual issue.

ricardoV94 commented 1 year ago

Open a discussion summarizing your opinion, and make sure it actually covers all the relevant material/points (e.g. all the old bugs and restrictions we'll have to accept again, show that there actually is a real and unacceptable performance trade-off currently, etc.).

Given that things have been broken by the switch, I would say the burden is still on justifying why we really need the changes; they were clearly not discussed in depth when we introduced them. Reverting the changes is a very possible solution to this issue and, as such, relevant here.

Happy to explore the pros and cons in a separate discussion, but I don't like anchoring the discussion so that dynamic broadcasting is the baseline and has to be argued against. Also, this was the very opposite of a hot take; I have been pondering it since I opened the issue nearly two months ago.

brandonwillard commented 1 year ago

To clarify the relationship between this Aesara issue and Theano's old broadcasting assumptions/TensorType.broadcastable interpretation, consider this issue's example in Theano:

import theano
import theano.tensor as tt
import numpy as np

theano.__version__
# '1.0.5'

X_row = tt.row("X_row")
X_matrix = tt.matrix("X_matrix")

def X_grad_fn_constructor(X):
    Y = tt.matrix("Y")
    X_sum = tt.sum(X + Y)
    X_grad = tt.grad(X_sum, wrt=X)
    X_grad_fn = theano.function([X, Y], X_grad)
    return X_grad_fn

X_grad_fn_row = X_grad_fn_constructor(X_row)
X_grad_fn_matrix = X_grad_fn_constructor(X_matrix)

# This input is broadcastable in the first dimension, but the `Type`-level
# representation of that fact is lacking in the `X_matrix` case.  Let's see how
# Theano handles this broadcast information disparity.
# To be clear, *both* cases should theoretically return the same values for the
# same inputs.
X_val = np.ones((1, 5))
Y_val = np.ones((5, 5))

# The "row"-`Type` case (i.e. we tell Theano that the first dimension of `X` is
# broadcastable)
row_res = X_grad_fn_row(X_val, Y_val)
row_res
# array([[5., 5., 5., 5., 5.]])

# The "matrix"-`Type` case (i.e. we *don't* tell Theano that the first
# dimension of `X` is actually broadcastable)
matrix_res = X_grad_fn_matrix(X_val, Y_val)
matrix_res
# array([[1., 1., 1., 1., 1.]])

assert np.array_equal(matrix_res, row_res)
# AssertionError:

In other words, Theano's assumptions were not capable of solving the issue raised here. Instead, the same shape inference problem(s) simply took different forms. (N.B. This also means that no amount of "reverting" will fix this issue.)

The thing we're trying to fix in this issue has always been an issue; only now (e.g. with TensorType.shape and the requisite logic clarifications) are we becoming equipped to sufficiently address it.

aseyboldt commented 1 year ago

I don't think this result has that much to do with the theano broadcasting rules, but it is an example of a rewrite that changes the results of a graph from an error to a value.

According to the old theano rules, computing X + Y, where X is not broadcastable, should result in an error, because they have incompatible shapes. This is in fact what happens if we just compute the sum:

x = tt.matrix("x")
y = tt.matrix("y")

(x + y).eval({x: np.ones((1, 3)), y: np.ones((3, 3))})
# ValueError

Just like any other incompatible shapes:

(x + y).eval({x: np.ones((2, 3)), y: np.ones((3, 3))})
# ValueError

If we now compute the gradient of (x + y).sum(), we see the exact same behavior in unoptimized graphs. If we allow more rewrites, however, they optimize away the shape checks, leading to incorrect results for any incompatible arrays:

z = (x + y).sum()
dz = theano.grad(z, x)

func_orig = theano.function([x, y], dz, mode="FAST_COMPILE")
func_rewrite = theano.function([x, y], dz, mode="FAST_RUN")

theano.printing.debugprint(func_orig)
# Elemwise{second} [id A] ''   1
#  |Elemwise{add,no_inplace} [id B] ''   0
#  | |x [id C]
#  | |y [id D]
#  |TensorConstant{(1, 1) of 1.0} [id E]

theano.printing.debugprint(func_rewrite)
# Alloc [id A] ''   2
#  |TensorConstant{(1, 1) of 1.0} [id B]
#  |Shape_i{0} [id C] ''   1
#  | |x [id D]
#  |Shape_i{1} [id E] ''   0
#    |x [id D]

func_orig(np.ones((2, 3)), np.ones((3, 3)))
# ValueError

func_rewrite(np.ones((2, 3)), np.ones((3, 3)))
# array([[1., 1., 1.],
#        [1., 1., 1.]])

In short, theano assumes that the gradient of x + y is ones_like(x), when in fact it is Assert(compatible(x, y))(ones_like(x)).
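A rough sketch of what that guarded gradient could look like when built by hand in Aesara (assuming the Assert Op is importable from aesara.raise_op; the compatibility condition is simplified):

import numpy as np

import aesara
import aesara.tensor as at
from aesara.raise_op import Assert  # assumed location of the Assert Op

x = at.matrix("x")
y = at.matrix("y")

# "compatible(x, y)": along every axis the lengths match, or x has length 1.
compatible = at.all(at.eq(x.shape, y.shape) | at.eq(x.shape, 1))
grad_x = Assert("x and y have incompatible shapes")(at.ones_like(x), compatible)

f = aesara.function([x, y], grad_x)
f(np.ones((3, 3)), np.ones((3, 3)))    # fine: returns the ones
# f(np.ones((2, 3)), np.ones((3, 3)))  # raises at run time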

brandonwillard commented 1 year ago

I don't think this result has that much to do with the theano broadcasting rules, but it is an example of a rewrite that changes the results of a graph from an error to a value.

I might not be clear on what you mean by "broadcasting rules" in this scenario, but, unless removing a rewrite somehow makes the gradients work, I don't see how this point is relevant.

The concern here is whether or not Theano and its old broadcasting rules and/or use/interpretation of TensorType.broadcastable somehow made the buggy graph in this issue produce the correct values. Since Theano does not support this issue's case (rewrites or not), arguments about where/when/why an error is raised are not relevant.

Again, it would be nicer if Theano raised a more appropriate error earlier, but that doesn't change the fact that the way broadcasting, Elemwise, and/or shape inference was implemented in Theano does not support this case either.

According to the old theano rules, computing X + Y, where X is not broadcastable, should result in an error, because they have incompatible shapes. This is in fact what happens if we just compute the sum:

x = tt.matrix("x")
y = tt.matrix("y")

(x + y).eval({x: np.ones((1, 3)), y: np.ones((3, 3))})
# ValueError

I hope you weren't trying to show how Theano was somehow more consistent, because you've just demonstrated one of the fundamental inconsistencies in Theano's interpretation of TensorType.broadcastable: i.e. values that contradict a TensorType.broadcastable's strict interpretation are allowed. I demonstrated this in my original comment explaining some of Theano's inconsistencies, so please refer to that.

First, notice how the ValueError is raised by the Elemwise Op at run-time (i.e. during thunk evaluation), and not by Type-level checks that would support the consistency of Theano's TensorType.broadcastable interpretation. On a related note, here's something to think about: what would happen if Theano actually was consistent and didn't allow this contradiction?

Don't forget, this is an issue about the Type-level representation, interpretation, and use of "static" shape information. We want—and need—these things to be consistent at the Type-level in Aesara, and they demonstrably aren't in Theano, so we absolutely cannot use Theano as some sort of positive point of comparison.

By the way, you said that—according to Theano—they "have incompatible shapes", but they obviously don't. This statement alone reveals a serious design/representation issue.

At best, one could say that Elemwise was only partially implemented, or had an unnecessary restriction imposed upon it. Why was that restriction unnecessary? It could've—and should've—been implemented at the Type-level and used to prevent the gradient complications while also supporting broadcasting in all other cases. This is how Elemwise works in Aesara, although we have yet to add the Type-level checks that would catch the bug/missing functionality in this issue earlier.

To clarify, we are currently using the same TensorType.broadcastable interpretations as Theano, and that's also the reason we're seeing the bug in this issue—that's also the reason that things still work in Aesara. Because of this, the idea of "going back to Theano's rules/interpretation/whatever" doesn't make any sense whatsoever.

All we've done with the recent Elemwise and TensorType.shape changes is make it possible to compute graphs that don't have "complete" static shape/broadcast information (i.e. for which we don't definitively know whether or not each dimension is broadcastable).
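For example (a sketch; the exact TensorType constructor signature may differ between versions):

import aesara.tensor as at

# A matrix whose first dimension's length is unknown at compile time: it may
# or may not turn out to be 1 (and hence broadcast) at run time.
x = at.TensorType("float64", shape=(None, 5))("x")

print(x.type.shape)          # (None, 5)
print(x.type.broadcastable)  # (False, False) -- not *known* to be broadcastable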

To put it another way, we've been extending the capabilities of Aesara and, as far as I can tell, we have yet to disable anything that used to work (in Theano or otherwise).

Just like any other incompatible shapes:

The rest of your comment confirms that Theano does not support this case, which was my entire point, so there's nothing to add.

aseyboldt commented 1 year ago

I might not be clear on what you mean by "broadcasting rules" in this scenario, but, unless removing a rewrite somehow makes the gradients work, I don't see how this point is relevant.

That's what I'm saying (at least for the theano definition of "make it work") :-)

By "broadcasting rules" I mean how I understand the theano broadcasting rules (and have understood for several year of using theano):

We can consider two different versions of how addition (or other elemwise operations) are defined for tensors:

  • The numpy definition: an axis of length 1 is stretched (broadcast) to match the corresponding axis of the other operand.
  • The math definition: the lengths along each axis have to match exactly; anything else is a shape error.

We now put information about which of those definitions we want to use into the type of a tensor: for each axis we can choose whether we want to follow the numpy or the math definition by setting the corresponding broadcastable flag. If broadcastable is True, we require the values of that type to have shape 1 along that axis and follow the numpy definition. If it is False, we allow any length but follow the math rules for that axis. I don't think this is explicitly stated like this in the docs (at least I haven't found anything), but it certainly describes well how theano works (discounting bugs). They at least explicitly state that an axis that has broadcastable=False can have length 1: https://theano-pymc.readthedocs.io/en/latest/library/tensor/basic.html?highlight=broadcastable#theano.tensor.TensorType.broadcastable

So in your example X and Y have incompatible shapes, since broadcastable=False for the dimension where X and Y differ, and we therefore follow the math definition. This is why we see a shape error if we compute X + Y. But there is a bug (at least I'd argue it's a bug; there's no formal definition of what a rewrite is allowed to do for illegal inputs that I know of, so I guess someone could argue that this is a valid rewrite, just invalid input data) that transforms the graph so that the result is changed from an error to a value.

(i.e. I think the rewrite from

theano.printing.debugprint(func_orig)
# Elemwise{second} [id A] ''   1
#  |Elemwise{add,no_inplace} [id B] ''   0
#  | |x [id C]
#  | |y [id D]
#  |TensorConstant{(1, 1) of 1.0} [id E]

to

theano.printing.debugprint(func_rewrite)
# Alloc [id A] ''   2
#  |TensorConstant{(1, 1) of 1.0} [id B]
#  |Shape_i{0} [id C] ''   1
#  | |x [id D]
#  |Shape_i{1} [id E] ''   0
#    |x [id D]

is incorrect, since the first errors out if x and y are incompatible (by the math definition), while the second one doesn't. So the rewrite changes the output of the graph, if we consider the error as part of the output. Anyway, this doesn't really have anything to do with broadcasting; we have the exact same problem for other incompatible X and Y.)

I think the theano definition of broadcastable is quite consistent and well defined (even though it wasn't written down properly anywhere that I know of), and it keeps graphs nice, compact, and easy to understand, mainly because we always know just from the graph where broadcasting happens and where it doesn't. It certainly did lead to some confusion over the years, however.

brandonwillard commented 1 year ago

To reiterate, this issue can be closed by

  1. raising an error at compile/construction-time stating that Elemwise.grad does not support cases with insufficient broadcast information, or
  2. making Elemwise.grad support cases with incomplete broadcast information.

This issue concerns itself with the relevant details of the above two approaches, and any new ones we may not have considered yet.

Conversations about TensorType.broadcastable, its relationship to the new TensorType.shape, etc., belong in https://github.com/aesara-devs/aesara/discussions/1170 or a new Discussion altogether.

ricardoV94 commented 1 year ago

Can we please stop marking comments as off-topic? Clearly this issue is at the level of a discussion, and we could even convert it into one. Marking comments as off-topic helps nobody.

I think everyone is almost on the same page about Theano, except for the question of why they left it to Elemwise, and not the Types, to enforce the broadcastable flags.

First, notice how the ValueError is raised by the Elemwise Op at run-time (i.e. during thunk evaluation), and not by Type-level checks that would support the consistency of Theano's TensorType.broadcastable interpretation.

It is now clear to me that only the Ops that perform broadcasting could enforce these broadcasting flags. broadcastable=False did not mean the tensors could not have a shape of 1, only that they could not be broadcasted in an operation that allows for broadcasting.

This constraint can only be assessed by looking at the Types of other inputs that go into an Elemwise Op (or any Op that performs broadcasting of the inputs internally) and their runtime shapes.

x = at.TensorType("float64", broadcastable=[False])("x")
y = at.TensorType("float64", broadcastable=[False])("y")
out = at.add(x, y)

out.eval({x:[0], y:[0]})  # Valid, as neither input had to be broadcasted
out.eval({x:[0, 1, 2], y:[0, 1, 2]})  # Valid, as neither input had to be broadcasted
out.eval({x:[0], y:[0, 1, 2]})  # Invalid, as x would have to be broadcasted

This is how Theano worked. I'm not saying that's the best approach, but I think it was coherent (as was other logic built on top of it, which still lingers), and we need to understand it to be able to move this discussion forward.

rlouf commented 1 year ago

To be fair, from an outsider's perspective, the discussion here is impossible to follow. The concerns that the issue raises are valid, but the discussion has since diverged to questions that are more fundamental and/or historical in nature. To move forward I suggest:

  1. We open a discussion where we lay out exactly where we want to go with shape inference and how we want to move forward generally;
  2. We can refer to the above discussion here to talk about the original issue specifically.
rlouf commented 1 year ago

I just ran @ricardoV94's example both on the current HEAD and on b60cf7240 (the commit before the changes in #928 were merged):

def test_ambiguous_broadcast():
    import aesara
    import aesara.tensor as at
    import numpy as np

    x_row = at.row("x_row")
    x_matrix = at.matrix("x_matrix")
    y = at.matrix("y")

    x_row_grad = at.grad(at.sum(x_row + y), wrt=x_row)
    x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

    f_row = aesara.function([x_row, y], x_row_grad)
    row_res = f_row(np.ones((1, 5)), np.ones((5, 5)))

    f_matrix = aesara.function([x_matrix, y], x_matrix_grad)
    assert np.array_equal(f_matrix(np.ones((1, 5)), np.ones((5, 5))), row_res)

test_ambiguous_broadcast()

and it fails in both situations. This means that this issue is not a consequence of #928. In other words, reverting this change is not going to fix this bug.

Try this for yourself (don't forget to clear the cache) so we can all agree on this point moving forward. Comment "I see it" (and only that for now) if you can reproduce it; comment something else only if you can't.

brandonwillard commented 1 year ago

I see it

ricardoV94 commented 1 year ago

@rlouf This is clearly explained in one of the "off-topic" comments as arising from an "inconsistent" rewrite: https://github.com/aesara-devs/aesara/issues/1089#issuecomment-1282776495

Running on b60cf7240

import aesara
import aesara.tensor as at
from aesara.compile.mode import Mode
import numpy as np

x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_matrix_grad = at.grad(at.sum(x_matrix + y), wrt=x_matrix)

f_matrix = aesara.function(
    [x_matrix, y], 
    x_matrix_grad, 
    mode=Mode().excluding("local_fill_to_alloc"),
)
matrix_res = f_matrix(np.ones((1, 5)), np.ones((5, 5)))  # ValueError

It's one consequence of the "rewrites assume original graphs were valid" mindset: https://theano-pymc.readthedocs.io/en/latest/tutorial/shape_info.html#shape-inference-problem (they mention that the same applies to Elemwise after the example).


If you want an example that clearly fails before that commit and passes after, make it just a bit more complex:

import aesara
import aesara.tensor as at
import numpy as np

x_matrix = at.matrix("x_matrix")
y = at.matrix("y")

x_matrix_grad = at.grad((at.sum(at.exp(x_matrix + y))), wrt=x_matrix)

f_matrix = aesara.function(
    [x_matrix, y], 
    x_matrix_grad, 
)
# ValueError before `b60cf7240` and not after
matrix_res = f_matrix(np.ones((1, 5)), np.ones((5, 5)))
rlouf commented 1 year ago

The reason I am writing this is that newcomers to the issue will take your first comment at face value, assume #928 is the problem, and get confused if they try to run it with the commit before that. We're not writing these just for ourselves.

If the example you gave me indeed fails after #928 and not before, could you please edit your original comment?

ricardoV94 commented 1 year ago

Thanks. I updated the original message to link to that comment.

brandonwillard commented 1 year ago

If the example you gave me indeed fails after #928 and not before, could you please edit your original comment?

Doing that changes the entire nature of this issue, which was never originally about that ValueError, so let's not add that confusion to the mix.

It would make more sense to describe the errant rewrites in the opening comment.

rlouf commented 1 year ago

Thanks. I updated the original message to link to that comment.

Unless I'm mistaken, you left the original code snippet that fails both before and after #928?

The reason I'm asking is that issues/PRs in the repository are not only read by those who interacted with them. We should always keep this in mind and strive to make them understandable to anyone vaguely familiar with Aesara. I spent hours trying to understand what was going on in this issue before realizing yesterday that the assertions in the original comment did not hold (namely, that the particular example you provided fails because of #928). Most people will not do this legwork and will take your original comment at face value. General confusion ensues.