aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Generalize `linalg.det` beyond 2D arrays #1065

Open purna135 opened 2 years ago

purna135 commented 2 years ago
import numpy as np
import aesara.tensor as at

x = np.full((2, 3, 3), np.eye(3))
np.linalg.det(x)  # fine: broadcasts over the leading dimension, returns array([1., 1.])
at.linalg.det(x)  # AssertionError: x.ndim == 2
guyrt commented 2 years ago

Happy to take this up, but I want to sanity-check. The current implementation in Aesara defers to NumPy:

https://github.com/aesara-devs/aesara/blob/main/aesara/tensor/nlinalg.py#L214

def perform(self, node, inputs, outputs):
    (x,) = inputs
    (z,) = outputs
    try:
        z[0] = np.asarray(np.linalg.det(x), dtype=x.dtype)
    except Exception:
        print("Failed to compute determinant", x)
        raise

The expectation would be to continue deferring to NumPy for the implementation, i.e. assume the last two dimensions have the same length and broadcast the determinant over the leading dimensions, while removing the ndim == 2 assert. That seems fine, but does anyone remember why the assert exists in Aesara specifically?
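For reference, this is the NumPy behavior being deferred to: the determinant is computed matrix-wise over the leading dimensions, so the result shape is the input's batch shape.

import numpy as np

x = np.random.default_rng(0).normal(size=(4, 2, 3, 3))
d = np.linalg.det(x)
assert d.shape == x.shape[:-2]  # one determinant per (4, 2) leading index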

brandonwillard commented 2 years ago

> Happy to take this up, but I want to sanity-check. The current implementation in Aesara defers to NumPy:

That would be great, but I believe this is one of many other pending Op implementations that would essentially amount to vectorize(op_fn, signature=...). In other words, our efforts need to be focused on https://github.com/aesara-devs/aesara/issues/695 (i.e. the Aesara version of numpy.vectorize), which would make the implementation of this Op trivial.
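As an analogy only (not the proposed Aesara API), NumPy's generalized vectorize already expresses this wrapping pattern:

import numpy as np

# Wrap a function with a 2D "core" signature and broadcast it over any
# leading dimensions.
batched_det = np.vectorize(np.linalg.det, signature="(m,m)->()")
batched_det(np.full((2, 3, 3), np.eye(3)))  # array([1., 1.])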

> https://github.com/aesara-devs/aesara/blob/main/aesara/tensor/nlinalg.py#L214
>
> def perform(self, node, inputs, outputs):
>     (x,) = inputs
>     (z,) = outputs
>     try:
>         z[0] = np.asarray(np.linalg.det(x), dtype=x.dtype)
>     except Exception:
>         print("Failed to compute determinant", x)
>         raise

> The expectation would be to continue deferring to NumPy for the implementation, i.e. assume the last two dimensions have the same length and broadcast the determinant over the leading dimensions, while removing the ndim == 2 assert. That seems fine, but does anyone remember why the assert exists in Aesara specifically?

Which assert? If you're referring to the one in Det.make_node, that's just some poorly implemented input validation; asserts like that should really be raises with the appropriate Exception types.
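For illustration, the style of validation being suggested could look like the following (hypothetical; shown with the current 2D restriction just as an example):

def make_node(self, x):
    x = at.as_tensor_variable(x)
    # Raise a proper exception instead of a bare `assert`
    if x.ndim != 2:
        raise ValueError(f"det expects a 2D input; got {x.ndim} dimensions")
    ...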

ricardoV94 commented 2 years ago

In this case it seems like nothing new is needed other than removing the assert and checking that infer_shape and grad still work, because numpy.linalg.det handles batched inputs out of the box: https://numpy.org/doc/stable/reference/generated/numpy.linalg.det.html

We should not need #695, because the limitation here was artificial to begin with.
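If that holds, the shape side is small; a hypothetical sketch of a batch-aware Det.infer_shape (the method body is mine, not the current source):

def infer_shape(self, fgraph, node, shapes):
    (x_shape,) = shapes
    # det collapses the trailing two (matrix) dimensions, leaving
    # only the batch shape.
    return [x_shape[:-2]]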

brandonwillard commented 2 years ago

> In this case it seems like nothing new is needed other than removing the assert and checking that infer_shape and grad still work

That's largely what https://github.com/aesara-devs/aesara/issues/695 provides, as well as a single interface for transpilation, etc.

brandonwillard commented 2 years ago

> We should not need #695, because the limitation here was artificial to begin with.

We absolutely do, if only to reduce the amount of redundant code, redundant testing time, time spent reviewing, and so on.

ricardoV94 commented 2 years ago

I don't get why. If the wrapped Op is not constrained to 2D inputs, why act like it is and use extra machinery to extend beyond it?

brandonwillard commented 2 years ago

> I don't get why.

As I said above, a Blockwise base class would provide generalized Op.infer_shape and Op.[grad|Lop|Rop] implementations that can be thoroughly tested once in generality. It compartmentalizes the testing and review concerns by separating the "core" Op from its broadcasting/"vectorization"—among other things.

Furthermore, an implementation in this scenario would look something like the following:

class Det(Op):
    # Perhaps some `nfunc_spec`-like information here...

    def perform(self, ...):
        # The current `Det.perform`...

    # Potentially nothing else...

det = Blockwise(Det, signature=((("m", "n"),), ((),)))

I don't see how it could get much easier than that.
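Under that hypothetical API, a call site would not need to know anything about batching:

import aesara.tensor as at

x = at.tensor3("x")  # a stack of matrices, shape (batch, m, m)
d = det(x)           # the hypothetical Blockwise-wrapped det above; shape (batch,)
# infer_shape, grad, etc. would come from Blockwise, not from Det itself.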

Also, we'll probably have something more general like nfunc_spec, which will allow one to specify a NumPy implementation to be used in Blockwise.perform.

Remember, we're basically talking about an extension to Elemwise, so imagine if we had no Elemwise and instead had a zoo of distinct Op implementations written (and reviewed) by more or less random developers, each upholding varying degrees of design and quality standards, feature requirements, and domain expertise. Now, imagine maintaining and extending that codebase over time. If you thought that making updates to Elemwise was difficult, imagine making those updates across all those redundant Ops that duplicate its functionality.

Also, consider how difficult it would be for external developers to implement custom Ops that fit into the Elemwise framework. If we had a Blockwise Op right now but no blockwise determinant in Aesara, we would still be better off, because it would be simple for an external developer to create their own vectorized determinant. As it stands, external developers need to know a lot in order to implement their own vectorized Op.infer_shape and Op.grad methods.

> If the wrapped Op is not constrained to 2D inputs, why act like it is and use extra machinery to extend beyond it?

Because there is no "extra machinery" if we do this the way I'm describing, but there is if we do it the way you seem to be advocating (i.e. creating redundant Ops with distinct Op.infer_shape and Op.grad implementations that all do overlapping subsets of the same things).

purna135 commented 2 years ago

Hello, @guyrt. Thank you so much for your interest; I am almost finished implementing det to work with inputs beyond 2D. The only thing I'm waiting on is PR #808, which I need in order to compute the grad of the det Op.
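For reference, the gradient that has to generalize is Jacobi's formula, d det(A)/dA = det(A) * inv(A).T, which NumPy already broadcasts over leading dimensions (a NumPy check, not the Aesara implementation):

import numpy as np

A = np.random.default_rng(0).normal(size=(2, 3, 3))
# Jacobi's formula, applied matrix-wise over the batch
g = np.linalg.det(A)[..., None, None] * np.swapaxes(np.linalg.inv(A), -1, -2)
assert g.shape == A.shape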

purna135 commented 2 years ago

Could someone assign this issue to me?

brandonwillard commented 2 years ago

> Could someone assign this issue to me?

@purna135, take a look at https://github.com/aesara-devs/aesara/issues/1065#issuecomment-1204740044. From an administrative standpoint, this project needs to focus its efforts on https://github.com/aesara-devs/aesara/issues/695, since that would close—or directly assist with closing—this issue, https://github.com/aesara-devs/aesara/issues/488, https://github.com/aesara-devs/aesara/issues/791, and https://github.com/aesara-devs/aesara/issues/612. It would also affect the formulation of solutions to https://github.com/aesara-devs/aesara/issues/1089, https://github.com/aesara-devs/aesara/issues/1094, https://github.com/aesara-devs/aesara/issues/933, https://github.com/aesara-devs/aesara/issues/617, https://github.com/aesara-devs/aesara/issues/770, https://github.com/aesara-devs/aesara/issues/547, https://github.com/aesara-devs/aesara/issues/712, and numerous rewrite-related simplifications and issues.

Sayam753 commented 2 years ago

Hi @brandonwillard. Purna's GSoC project is about expanding support for multivariate distributions in PyMC to work with batched data. The project's scope currently includes fixing the Ops that multivariate distributions depend on: just a handful of Aesara Ops whose infer_shape and grad/L_op/R_op methods need fixes, namely solve, cholesky, det, eigh, and matrix_inverse.

Working on #695 is currently out of scope, although I do see long-term benefits if efforts are redirected towards https://github.com/aesara-devs/aesara/issues/695.

If you like, we can redefine the scope of the project to first work on https://github.com/aesara-devs/aesara/issues/695 and then work on multivariate distributions on the PyMC side.

brandonwillard commented 2 years ago

On a related note, see the note in NumPy's documentation for numpy.class.__array_ufunc__. It looks like they are/were in the process of making changes similar to the ones I've been talking about with respect to Blockwise.