ezyang / SMT-LIB-benchmarks-pytorch-shapes

SMT-LIB benchmarks for shape computations from deep learning models in PyTorch

General questions to shape compute expressions #1

Closed ganler closed 1 year ago

ganler commented 1 year ago

Thanks for the awesome work! I want to ask a few clarification questions regarding the shape compute expressions:

  1. Just curious what is the overall method for collecting those formulas? I guess it is that the PyTorch compiler implemented the shape propagation function symbolically for each operator and you converted the shape compute expressions from PyTorch's internal symbolic representation to smt-lib expression?
  2. The full list of arithmetic operations includes some interesting floating-point shape computations such as pow, truediv, sqrt, etc. Could you kindly indicate examples of operators where those operations are used?
  3. Some operators are input-value or runtime dependent -- i.e., the values/elements of the input tensor will affect the output shape such that shape computation cannot be done from the abstract domain (i.e., with only tensor type & constant information). Those examples include Nonzero, Gather, NMS, etc. Just curious if and how those operators are modeled in shape computation (I guess in most cases the compilers would assume some output shape dimensions are unknown that will be represented with some new symbols?).

On the open question of strengthening preconditions (I am answering the "is it necessary?" part), I think it would be very interesting and helpful for model serving:

  1. Security: we can reject invalid input shapes up front to prevent DoS attacks, since many reported compiler/runtime crashes are triggered by combinations of invalid inputs and API attributes (see the TensorFlow security advisories).
  2. Dynamic inference: in video analytics, different image resolutions can be used to trade off performance against accuracy, but we cannot use arbitrary resolutions because some of them may be invalid. So we need to infer the set of valid input shapes and then sample from it.
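The dynamic-inference use case can be pictured as "filter candidate shapes by a precondition, then sample". A minimal sketch, assuming a purely hypothetical precondition (both spatial dims must be positive multiples of 32); `is_valid_resolution`, `valid_resolutions` and `sample_resolution` are illustrative names, not part of any library:

```python
import random
from typing import List, Sequence, Tuple

Resolution = Tuple[int, int]

def is_valid_resolution(h: int, w: int, multiple: int = 32) -> bool:
    # Hypothetical precondition for illustration only: the model is assumed
    # to accept only positive multiples of 32 in both spatial dims.
    return h > 0 and w > 0 and h % multiple == 0 and w % multiple == 0

def valid_resolutions(candidates: Sequence[Resolution]) -> List[Resolution]:
    """Keep only the candidate (H, W) pairs that satisfy the precondition."""
    return [(h, w) for h, w in candidates if is_valid_resolution(h, w)]

def sample_resolution(candidates: Sequence[Resolution], rng: random.Random) -> Resolution:
    """Pick one valid resolution at random, failing loudly if none exists."""
    valid = valid_resolutions(candidates)
    if not valid:
        raise ValueError("no candidate resolution satisfies the shape constraints")
    return rng.choice(valid)

candidates = [(224, 224), (240, 320), (480, 640), (250, 250)]
print(valid_resolutions(candidates))  # [(224, 224), (480, 640)]
print(sample_resolution(candidates, random.Random(0)))
```

With real models the precondition would come from the extracted shape constraints rather than being hand-written like this.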

But I am curious what the key challenges would be for achieving this. If we just want to reject invalid input shapes/dtypes ahead of time, we can just use shape prop functions to propagate the shapes and perform assertion for each operator (assuming they are implemented symbolically), which can be done at the operator level. In fact, we have a tool called NNSmith (paper) that uses these two properties to generate random yet valid DNNs for fuzzing DL compilers. The main challenge we found is that we had to manually specify shape propagation and input constraints for each operator, because there is no systematic, general way to extract such information directly from open-source libraries, and doing it by hand does not scale to a large number of operators. That is why the way this benchmark was generated looks so interesting to me!
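The per-operator "propagate shapes and assert" approach described above can be sketched as follows. This is a minimal illustration on concrete shapes, assuming a hypothetical two-layer `relu(x @ w1) @ w2` model; `ShapeError`, `matmul_shape`, `mlp_shape` and `accepts` are made-up names, not NNSmith's or PyTorch's API:

```python
from typing import Tuple

Shape = Tuple[int, ...]

class ShapeError(ValueError):
    """Raised when an operator's input-shape precondition fails."""

def require(cond: bool, msg: str) -> None:
    # Explicit check rather than `assert`, so it is not stripped by `python -O`.
    if not cond:
        raise ShapeError(msg)

def matmul_shape(a: Shape, b: Shape) -> Shape:
    # Shape propagation plus precondition for a 2-D matmul.
    require(len(a) == 2 and len(b) == 2, "this sketch only handles 2-D matmul")
    require(a[1] == b[0], f"inner dims differ: {a[1]} vs {b[0]}")
    return (a[0], b[1])

def relu_shape(a: Shape) -> Shape:
    return a  # elementwise: no precondition, shape unchanged

def mlp_shape(x: Shape, w1: Shape, w2: Shape) -> Shape:
    """Propagate shapes through the hypothetical relu(x @ w1) @ w2 model."""
    return matmul_shape(relu_shape(matmul_shape(x, w1)), w2)

def accepts(x: Shape, w1: Shape, w2: Shape) -> bool:
    """Reject invalid inputs ahead of time, without running any kernels."""
    try:
        mlp_shape(x, w1, w2)
        return True
    except ShapeError:
        return False

print(mlp_shape((8, 16), (16, 32), (32, 4)))  # (8, 4)
print(accepts((8, 15), (16, 32), (32, 4)))    # False
```

The hard part the paragraph points at is not this composition but writing `matmul_shape`-style rules for hundreds of operators.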

Sorry for bundling so many questions together. I really appreciate the great work!

ezyang commented 1 year ago

Just curious what is the overall method for collecting those formulas? I guess it is that the PyTorch compiler implemented the shape propagation function symbolically for each operator and you converted the shape compute expressions from PyTorch's internal symbolic representation to smt-lib expression?

Yep! We did the laborious work of writing shape propagation functions for all the operators (sometimes manually, sometimes by decomposing the operator into other operators), and then ran the compiler stack to get all of the shape compute. One other thing to note is that the computation here is traced: we flattened away all the conditionals.
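One plausible way to picture "traced, with the conditionals flattened away": a branch on a symbolic size is decided using the concrete value from the example input, and the path taken is recorded as a guard, so the emitted shape expression has no conditionals left. A minimal sketch under that assumption; `SymInt`, `out_len` and the guard list are hypothetical and are not PyTorch's internals:

```python
from typing import List, Tuple, Union

class SymInt:
    """Hypothetical symbolic int: an SMT-LIB-style expression plus the concrete
    value (hint) it had for the example input used while tracing."""

    def __init__(self, expr: str, hint: int, guards: List[str]):
        self.expr, self.hint, self.guards = expr, hint, guards

    @staticmethod
    def _operand(other: Union["SymInt", int]) -> Tuple[str, int]:
        if isinstance(other, SymInt):
            return other.expr, other.hint
        return str(other), other

    def _binop(self, op: str, other, fn) -> "SymInt":
        o_expr, o_hint = self._operand(other)
        return SymInt(f"({op} {self.expr} {o_expr})", fn(self.hint, o_hint), self.guards)

    def __add__(self, other):
        return self._binop("+", other, lambda a, b: a + b)

    def __sub__(self, other):
        return self._binop("-", other, lambda a, b: a - b)

    def __mul__(self, other):
        return self._binop("*", other, lambda a, b: a * b)

    def __ge__(self, other) -> bool:
        # A Python `if` needs a concrete bool: decide it with the hint and
        # record the path taken as a guard instead of keeping the branch.
        o_expr, o_hint = self._operand(other)
        taken = self.hint >= o_hint
        cond = f"(>= {self.expr} {o_expr})"
        self.guards.append(cond if taken else f"(not {cond})")
        return taken

def out_len(n):
    # An ordinary shape function with a branch on a size.
    if n >= 8:
        return n - 4
    return n * 2

guards: List[str] = []
result = out_len(SymInt("s0", 10, guards))
print(result.expr, guards)  # (- s0 4) ['(>= s0 8)']
```

Tracing with a different example value (say 5) would instead yield `(* s0 2)` under the guard `(not (>= s0 8))`.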

The full list of arithmetic operations includes some interesting floating-point shape computations such as pow, truediv, sqrt, etc. Could you kindly indicate examples of operators where those operations are used?

Hmm, I guess in the next revision I should include source pointers for the SMT expressions :) Often these arithmetic operations are from the models themselves, not operators.

For example, the first sqrt in jx_nest_base comes from here:

@register_notrace_function  # reason: int receives Proxy
def deblockify(x, block_size: int):
    """blocks to image
    Args:
        x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
        block_size (int): edge length of a single square block in units of desired H, W
    """
    B, T, _, C = x.shape
    grid_size = int(math.sqrt(T))
    height = width = grid_size * block_size
    x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
    x = x.transpose(2, 3).reshape(B, height, width, C)
    return x  # (B, H, W, C)
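The sqrt in `deblockify` also carries an implicit shape precondition: the first reshape only succeeds if the element counts match. A shape-only sketch, assuming the reshape is the only constraint that matters; `deblockify_out_shape` is a hypothetical name, and it uses `math.isqrt` (Python 3.8+) to stay in integer arithmetic instead of `int(math.sqrt(T))`:

```python
import math
from typing import Tuple

Shape4 = Tuple[int, int, int, int]

def deblockify_out_shape(shape: Shape4, block_size: int) -> Shape4:
    """Shape-only sketch of the deblockify excerpt: (B, T, N, C) -> (B, H, W, C)."""
    B, T, N, C = shape
    grid_size = math.isqrt(T)  # floor(sqrt(T)) computed with integers only
    height = width = grid_size * block_size
    # reshape(B, g, g, bs, bs, C) preserves the element count only if
    # T * N == (g * bs) ** 2; deblockify never checks this itself, the reshape does.
    if T * N != height * width:
        raise ValueError(f"shape {shape} is incompatible with block_size={block_size}")
    return (B, height, width, C)

print(deblockify_out_shape((2, 16, 64, 96), 8))  # (2, 32, 32, 96)
```

For example, T = 15 blocks would give grid_size = 3 and fail the check, which is exactly the kind of input a strengthened precondition could reject up front.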

Some of the truedivs look like they're coming from batchnorm:

        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
            # numerics probably don't matter.
            unbiased_var = torch.var(input, reduction_dims, unbiased=False) * (
                n / (n - 1)
            )

Actually, I probably should not have put these in the dataset, since we end up not using this for a shape compute; it just gets fed into some tensor computation. If you dead-code-eliminate (DCE) the expressions, this truediv will probably evaporate. Some other occurrences of truediv come from arange, which has a funny computation for determining its output size:

    shape = (math.ceil((end - start) / step),)

(And yes, end/start/step can be floats here; this leads to funny bugs when end is too big.)
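One way a too-big `end` goes wrong (not necessarily the exact bugs meant above): `(end - start) / step` is rounded to a double before `ceil` runs, and doubles cannot represent every integer above 2**53. A minimal sketch comparing the excerpt's formula with exact integer ceiling division; both function names are hypothetical, and the integer version assumes the sign of `step` matches the direction from `start` to `end`:

```python
import math

def arange_len_float(start, end, step):
    # Same formula as the excerpt: true division (a double), then ceil.
    return math.ceil((end - start) / step)

def arange_len_int(start: int, end: int, step: int) -> int:
    # Exact ceiling division on Python ints: ceil(a / b) == -((-a) // b).
    return -((start - end) // step)

end = 2**53 + 1  # the smallest positive integer a double cannot represent
print(arange_len_float(0, end, 1))  # 9007199254740992, off by one
print(arange_len_int(0, end, 1))    # 9007199254740993
```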

It doesn't look like pow shows up in the dataset anywhere.

Some operators are input-value or runtime dependent -- i.e., the values/elements of the input tensor will affect the output shape such that shape computation cannot be done from the abstract domain (i.e., with only tensor type & constant information). Those examples include Nonzero, Gather, NMS, etc. Just curious if and how those operators are modeled in shape computation (I guess in most cases the compilers would assume some output shape dimensions are unknown that will be represented with some new symbols?).

Today, we graph break on these, so that the subsequent graph can take these dynamically sized tensors as inputs and generate a kernel that can handle differing sizes. We're also looking into tracing these into a single graph, which will involve allocating a fresh symbol. The challenge is that at compile time we don't know the actual value of that symbol, so we can no longer evaluate boolean expressions on it (which, in practice, means you'll error out pretty quickly). There is more discussion about this at https://github.com/pytorch/pytorch/pull/90624
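To illustrate why a fresh symbol makes boolean expressions undecidable at compile time, here is a minimal sketch assuming a toy model where each size symbol either has a concrete example value (a hint) or, for data-dependent outputs like nonzero, has none. `Sym`, `DataDependentError` and `nonzero_count` are hypothetical names, not PyTorch APIs, and the `s`/`u` naming is only for readability:

```python
import itertools
from typing import Optional

class DataDependentError(Exception):
    """Hypothetical: a boolean on a symbol whose value is unknown at compile time."""

class Sym:
    def __init__(self, name: str, hint: Optional[int] = None):
        self.name = name
        self.hint = hint  # concrete example value, or None if data-dependent

    def __ge__(self, other: int) -> bool:
        if self.hint is None:
            raise DataDependentError(f"cannot decide ({self.name} >= {other}) at compile time")
        return self.hint >= other

_unbacked = itertools.count()

def nonzero_count(numel: Sym) -> Sym:
    # How many elements are nonzero depends on tensor *values*, not on numel,
    # so the output size gets a fresh symbol with no known value.
    return Sym(f"u{next(_unbacked)}")

n = Sym("s0", hint=10)  # input size: traced with a concrete example value
k = nonzero_count(n)    # output size: fresh symbol, value unknown
print(n >= 1)           # True
try:
    if k >= 1:          # e.g. a later shape function branching on the output size
        pass
except DataDependentError as e:
    print(e)            # cannot decide (u0 >= 1) at compile time
```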

If we just want to reject invalid input shapes/dtypes ahead of time, we can just use shape prop functions to propagate the shapes and perform assertion for each operator

Yes, if you only care about soundness, this is really all you need to do (besides the coverage problem, which we're just solving by brute force). However, there's also a problem where you want to simplify all of the tests you want to do. It's not really academically interesting, since the usual compiler/JIT literature tells you how to do it, but we still have to do it (and maybe there are some domain-specific insights that apply here).

ganler commented 1 year ago

Thanks for the detailed references! The floating-point example is interesting, though I think we can still use (end - start + step - 1) // step to compute the ceiling with integer arithmetic.
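A quick check of the integer-ceiling idea, as a sketch: `(a + b - 1) // b` equals `ceil(a / b)` for any integer `a` and positive `b`, but arange also accepts a negative step, where that formula is off; `-((-a) // b)` works for either sign. Both function names are hypothetical:

```python
import math

def ceil_div_pos(a: int, b: int) -> int:
    # The formula suggested above; valid only for b > 0.
    return (a + b - 1) // b

def ceil_div(a: int, b: int) -> int:
    # Valid for either sign of b, because Python's // floors toward -inf.
    return -((-a) // b)

# Both agree with math.ceil(a / b) for positive divisors ...
assert all(ceil_div_pos(a, b) == ceil_div(a, b) == math.ceil(a / b)
           for a in range(-20, 21) for b in range(1, 6))
# ... but arange(9, 0, -3) yields [9, 6, 3], i.e. length ceil((0 - 9) / -3) == 3:
print(ceil_div_pos(0 - 9, -3), ceil_div(0 - 9, -3))  # 4 3
```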

However, there's also a problem where you want to simplify all of the tests you want to do.

Does that mean we want a minimal but still complete set of predicates for checking the input shapes, one that both minimizes the runtime checking overhead and yields simplified test cases for each case? If so, one way to do it (as you also mentioned in the README) is to obtain a set of predicates P and try to minimize P by:

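import z3
for p in list(P):  # iterate over a snapshot, since P is rebuilt inside the loop
    s = z3.Solver()
    s.add(*[pp for pp in P if pp is not p], z3.Not(p))  # can the others hold while p fails?
    if s.check() == z3.unsat:  # no: the remaining predicates imply p, so p is redundant
        P = [pp for pp in P if pp is not p]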

But yeah, it could be costly when P is large, since this makes one solver query per predicate...

ezyang commented 1 year ago

We don't even want to be calling out to z3 at all! We're hoping there is a miniature solver we can add that will do this.
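As an illustration of what a "miniature solver" could cover without z3, here is a minimal sketch that handles only conjunctions of single-symbol integer bounds (`s >= c`, `s <= c`, `s == c`): it keeps the tightest interval per symbol, drops implied bounds, and detects contradictions. `Bound` and `simplify` are hypothetical names, and real guards (such as the sqrt or ceil relations above) would need more than interval reasoning:

```python
import math
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class Bound:
    sym: str    # symbol name, e.g. "s0"
    op: str     # one of ">=", "<=", "=="
    value: int

def simplify(preds: List[Bound]) -> List[Bound]:
    """Collapse single-symbol integer bounds into the tightest interval per symbol."""
    ranges: Dict[str, Tuple[float, float]] = {}
    for p in preds:
        if p.op not in (">=", "<=", "=="):
            raise ValueError(f"unsupported predicate: {p}")
        lo, hi = ranges.get(p.sym, (-math.inf, math.inf))
        if p.op in (">=", "=="):
            lo = max(lo, p.value)
        if p.op in ("<=", "=="):
            hi = min(hi, p.value)
        ranges[p.sym] = (lo, hi)

    out: List[Bound] = []
    for sym, (lo, hi) in ranges.items():
        if lo > hi:
            raise ValueError(f"constraints on {sym} are unsatisfiable")
        if lo == hi:
            out.append(Bound(sym, "==", int(lo)))
            continue
        if lo != -math.inf:
            out.append(Bound(sym, ">=", int(lo)))
        if hi != math.inf:
            out.append(Bound(sym, "<=", int(hi)))
    return out

guards = [Bound("s0", ">=", 0), Bound("s0", ">=", 2), Bound("s0", "<=", 1024),
          Bound("s1", ">=", 1), Bound("s1", "==", 3)]
print(simplify(guards))
# [Bound(sym='s0', op='>=', value=2), Bound(sym='s0', op='<=', value=1024), Bound(sym='s1', op='==', value=3)]
```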

ganler commented 1 year ago

Ah, I see. I'm not sure whether sympy has a built-in for this, or whether it could be implemented with its "Implies" function...

Thanks for the comments! I learned a lot!