apple / coremltools

Core ML tools contain supporting tools for Core ML model conversion, editing, and validation.
https://coremltools.readme.io
BSD 3-Clause "New" or "Revised" License

[PyTorch] torch.listconstruct causing issue for other ops #1926

Open YifanShenSZ opened 1 year ago

YifanShenSZ commented 1 year ago

This is the root cause of many issues. When a symbolic shape is involved in torch.listconstruct, we simply return the list as is instead of a Core ML tensor:

def _array_construct(context, node, array_type):

    ...

    else:
        # If at least one input to the construct op is non-const, collect
        # the inputs and add them directly to the context. Ops that use this
        # node's output will take the list directly as input.
        context.add(array_type(inputs), node.name)
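To make the two return paths concrete, here is a minimal sketch in plain Python. It does not use coremltools; `Var` and `array_construct_sketch` are hypothetical stand-ins introduced only for illustration, and the real converter logic is more involved.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Var:
    """Hypothetical stand-in for a converter variable; val is None when only known at runtime."""
    name: str
    val: object = None


def array_construct_sketch(inputs: List[Var]) -> Union[Var, List[Var]]:
    """Mimic the two branches: a single constant variable, or the raw Python list."""
    if all(v.val is not None for v in inputs):
        # Every input is a compile-time constant: fold them into one variable.
        return Var(name="const_list", val=[v.val for v in inputs])
    # At least one input is symbolic: the list itself is handed to downstream ops.
    return list(inputs)


static = array_construct_sketch([Var("7", 0), Var("8", 0)])
mixed = array_construct_sketch([Var("7", 0), Var("8", 0), Var("6"), Var("5")])
print(type(static).__name__)  # Var
print(type(mixed).__name__)   # list
```

An op that expects a single tensor-like input has to handle the second return type separately, which is where the breakage described in this issue tends to show up.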

Appendix 1: Issues Sharing the Same Root Cause

Appendix 2: Ops Impacted by the Root Cause

YifanShenSZ commented 1 year ago

In the proposed fix to issue #1303, we are trying to gather symbols

    else:
        # Create the new_shape and input_shape to support dynamic sizes.
        xshape = mb.shape(x=x)
        for i, v in enumerate(x.shape[2:]):
            if is_symbolic(v):
                si = mb.gather(x=xshape, indices=i+2, axis=0)
                new_shape.append(si)
                input_shape[i+2] = si
            else:
                new_shape.append(v)

YifanShenSZ commented 1 year ago

In reproducing issue #1921, torch.listconstruct is found to be the culprit

(Pdb) print(context.torch_graph)
graph(
    %x : Tensor(1, 192, RangeDim(lower_bound=2, upper_bound=1024, default=10, symbol="is0"), 'None'),
):
  %2 = constant[value=2]()
  %3 = size[](%x, %2)
  %pad_length = numtotensor[](%3)
  %5 = int[](%pad_length)
  %6 = int[](%pad_length)
  %7 = constant[value=0]()
  %8 = constant[value=0]()
  %9 = listconstruct[](%7, %8, %6, %5)
  %10 = constant[value=constant]()
  %11 = constant[]()
  %12 = pad[](%x, %9, %10, %11)
return (%12)
(Pdb) print(node)
  %12 = pad[](%x, %9, %10, %11)

The pad, i.e. variable %9, is the output of torch.listconstruct, which consists of 2 consts and 2 symbols
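As a worked example of what this pad list means, the sketch below computes the padded output shape in plain Python. It relies on the torch convention that pad pairs are listed starting from the last dimension; `padded_shape` is a hypothetical helper, and the symbolic dimension is fixed at its default of 10 purely for illustration.

```python
from typing import List, Sequence


def padded_shape(shape: Sequence[int], pad: Sequence[int]) -> List[int]:
    """Output shape of torch-style padding: pad pairs start at the last dimension."""
    if len(pad) % 2 != 0 or len(pad) > 2 * len(shape):
        raise ValueError("pad must hold (before, after) pairs for at most every dimension")
    out = list(shape)
    for i in range(len(pad) // 2):
        axis = len(shape) - 1 - i  # pair i applies to the i-th dimension from the end
        out[axis] += pad[2 * i] + pad[2 * i + 1]
    return out


# Input from the graph, with the symbolic dimension at its default of 10.
x_shape = [1, 192, 10]
pad_length = x_shape[2]                 # %3 = size(%x, 2)
pad = [0, 0, pad_length, pad_length]    # %9 = listconstruct(%7, %8, %6, %5)
print(padded_shape(x_shape, pad))       # [1, 212, 10]
```

The last dimension gets (0, 0) and the second-to-last gets (pad_length, pad_length), so the two symbolic entries both depend on the runtime size of dimension 2.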

xorange commented 11 months ago

Yes, for the #1921 case (for op 'pad'):

# pseudo code
@register_torch_op
def listconstruct():
    if constant shape:        # no bug case
        return mb.const(val=[static shapes, ...])

    else:                              # #1921 case
        return [mixed of static shapes, and dynamic shapes, ...]    # which failed to parse in op 'pad'

If we want to fix it referring to the solution in #1303 ( #1922 ), we basically want:

# pseudo code
def listconstruct():
    if constant shape:        # no bug case
        return mb.const(val=[static shapes, ...])

    else if_match_#1921_case():
        static_shapes = [...]    # extract static shapes from inputs
        sliced_dynamic_shapes = mb.slice_by_size(x=, begin=[], size=)    # extract dynamic shape syms from inputs
        return mb.concat(values=[static_shapes, sliced_dynamic_shapes])    # mb.concat() the static and dynamic parts of the shape

To be specific, in #1921 case:

>>> context.torch_graph
graph(
    %x : Tensor(1, 192, RangeDim(lower_bound=2, upper_bound=1024, default=10, symbol="is0"), 'None'),
):
  %2 = constant[value=2]()
  %3 = size[](%x, %2)
  %pad_length = numtotensor[](%3)
  %5 = int[](%pad_length)
  %6 = int[](%pad_length)
  %7 = constant[value=0]()
  %8 = constant[value=0]()
  %9 = listconstruct[](%7, %8, %6, %5)
  %10 = constant[value=constant]()
  %11 = constant[]()
  %12 = pad[](%x, %9, %10, %11)
return (%12)
>>> node.name               # I pdb in def _array_construct()
'9'

>>> node.inputs
['7', '8', '6', '5']
>>> context['7'].val        # static shapes
0
>>> context['8'].val        # static shapes
0
>>> context['6'].op         # dynamic shapes
  %6: (int32)(Scalar) = cast(x=%gather_0, dtype="int32", name="6")

>>> context['5'].op         # dynamic shapes
  %5: (int32)(Scalar) = cast(x=%gather_0, dtype="int32", name="5")

If we hard-code the solution for it:

def _array_construct(context, node, array_type):

    ...

    else:
        # context.add(array_type(inputs), node.name)

        static_7_8 = [context['7'].val, context['8'].val]
        sliced_dynamic_6 = mb.slice_by_size(x=mb.shape(x=context['x']), begin=[2], size=[1])
        sliced_dynamic_5 = mb.slice_by_size(x=mb.shape(x=context['x']), begin=[2], size=[1])
        context.add(mb.concat(values=[static_7_8, sliced_dynamic_6, sliced_dynamic_5], axis=0), node.name)
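One way to think about a less hard-coded variant is to walk the list inputs in order and turn each element into a rank-1 piece before concatenating. The sketch below emulates that idea in plain Python only; it does not call coremltools, and `build_pieces` and `concat_at_runtime` are hypothetical names used for illustration, not part of any existing API.

```python
from typing import Dict, List, Optional, Tuple

# Each input is (name, val); val is None when the value is only known at runtime.
Input = Tuple[str, Optional[int]]


def build_pieces(inputs: List[Input]) -> List[Tuple[str, object]]:
    """Turn every list element into a rank-1 piece, ready to be concatenated in order."""
    pieces: List[Tuple[str, object]] = []
    for name, val in inputs:
        if val is not None:
            pieces.append(("const", [val]))    # would become a 1-element constant
        else:
            pieces.append(("dynamic", name))   # would become the scalar expanded to shape (1,)
    return pieces


def concat_at_runtime(pieces: List[Tuple[str, object]], runtime_values: Dict[str, int]) -> List[int]:
    """Emulate the concat once the symbolic scalars have concrete values."""
    out: List[int] = []
    for kind, payload in pieces:
        out.extend(payload if kind == "const" else [runtime_values[payload]])
    return out


inputs = [("7", 0), ("8", 0), ("6", None), ("5", None)]
pieces = build_pieces(inputs)
print(concat_at_runtime(pieces, {"6": 10, "5": 10}))  # [0, 0, 10, 10]
```

The point of the sketch is that element order is preserved without naming specific nodes such as '7' or '6'; whether this maps cleanly onto the converter's ops is an open question in this thread.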

and modify the pad op registration so that it supports padding with symbolic values:

@register_torch_op(torch_alias=['constant_pad_nd'])
def pad(context, node):

    ...

    if pad.val is not None:
        ...

    else:
        missing_dims = (x.rank * 2 - pad.shape[0]) // 2
        pad = mb.concat(values=[pad, [0, 0] * missing_dims], axis=0)
        pad = mb.reverse(x=pad, axes=[0])
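A detail worth checking in the else branch is the reversal. The plain-Python sketch below compares an element-wise reverse (as in the snippet above) with a pair-wise reverse that keeps each (before, after) pair intact. It assumes the target layout lists (before, after) pairs from the first dimension to the last; that layout, and the helper names, are assumptions made for illustration.

```python
from typing import List


def pad_to_rank(pad: List[int], rank: int) -> List[int]:
    """Append (0, 0) pairs for the leading dimensions the torch pad list leaves out."""
    missing_dims = (rank * 2 - len(pad)) // 2
    return pad + [0, 0] * missing_dims


def reverse_full(pad: List[int]) -> List[int]:
    """Element-wise reverse of the whole list."""
    return pad[::-1]


def reverse_pairs(pad: List[int]) -> List[int]:
    """Reverse the order of (before, after) pairs, keeping each pair intact."""
    pairs = [pad[i:i + 2] for i in range(0, len(pad), 2)]
    return [v for pair in reversed(pairs) for v in pair]


symmetric = pad_to_rank([0, 0, 10, 10], rank=3)
asymmetric = pad_to_rank([1, 2, 3, 4], rank=3)
print(reverse_full(symmetric) == reverse_pairs(symmetric))  # True
print(reverse_full(asymmetric))   # [0, 0, 4, 3, 2, 1]
print(reverse_pairs(asymmetric))  # [0, 0, 3, 4, 1, 2]
```

For the symmetric padding in #1921 both reversals agree, which may be why the temporary fix works there; with asymmetric padding the element-wise reverse swaps the before and after amounts within each pair.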

#1921 is confirmed fixed (temporarily).

xorange commented 11 months ago

However, I have some doubts before proposing a general fix for it.

In the above case, the padding value is hard-coded:

My question:

How can we relate node '9'.inputs all the way back to the symbolic value of input x, with data structure:

Edit: I've noticed context.torch_graph.nodes and it answers my question above. WIP...

YifanShenSZ commented 11 months ago

Hi @xorange, thanks for looking into this issue! About relating a symbol to input symbols, you should be able to simply compare if those symbols are the same: we propagate symbols using sympy
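The comparison idea can be sketched in plain Python as a lookup of which input and axis carry a given symbol. Strings stand in for the sympy symbols here, and `find_symbol` is a hypothetical helper; the real shapes hold symbol objects rather than strings.

```python
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]  # a str stands in for a symbolic dimension such as "is0"


def find_symbol(inputs: Dict[str, List[Dim]], symbol: str) -> List[Tuple[str, int]]:
    """Return every (input name, axis) whose dimension equals the given symbol."""
    return [
        (name, axis)
        for name, shape in inputs.items()
        for axis, dim in enumerate(shape)
        if dim == symbol
    ]


# Input shape from the #1921 graph: Tensor(1, 192, RangeDim(..., symbol="is0")).
graph_inputs = {"x": [1, 192, "is0"]}
print(find_symbol(graph_inputs, "is0"))  # [('x', 2)]
```

With the location known, a dynamic list element that carries the same symbol could be traced back to a specific axis of a specific graph input.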

As for the fix, I have several thoughts that might be easier:

  1. pad-specific fix: In pad, given the constructed list, is it possible to use something like mb.stack, mb.concat, or mb.gather to construct a tensor?
  2. general fix: The ultimate problem is, if we change _array_construct output signature, it would break backward compatibility 😞 All functions that rely on "list of symbols" rather than tensor would break

xorange commented 11 months ago

Quoting from #2050:

  1. This fix adds another branch that only targets the gather op, and what it gathers from is not a name in context.torch_graph.nodes, i.e. it is one of the net inputs.

Given #2037, and another net structure I have on hand that shares a similar root cause, it is clear that targeting only the gather op is not enough to provide a generalized fix.

For example,

>>> context.torch_graph
graph(
    %x.1 : Tensor(1, 3, RangeDim(lower_bound=300, upper_bound=400, default=300, symbol="is0"), RangeDim(lower_bound=300, upper_bound=400, default=300, symbol="is1"), 'None'),
    ...
)
%input.1 = _convolution[](%x.1, %model.features.conv1.0.weight, %8, %31, %32, %33, %12, %34, %11, %12, %12, %14, %14)

or

>>> context.torch_graph
graph(
    %x.1 : Tensor(1, RangeDim(lower_bound=5, upper_bound=512, default=275, symbol="is0"), 'None'),
    %x_mask : Tensor(1, 1, RangeDim(lower_bound=5, upper_bound=512, default=275, symbol="is1"), 'None'),
    ...
)
%9 = embedding[](%emb.weight, %x.1, %7, %6, %6)
%x.3 = mul[](%9, %10)
%x.5 = transpose[](%x.3, %12, %13)
%x.7 = mul[](%x.5, %x_mask)
%input.1 = mul[](%x.7, %x_mask)

xorange commented 11 months ago

Hi @xorange, thanks for looking into this issue! About relating a symbol to input symbols, you should be able to simply compare if those symbols are the same: we propagate symbols using sympy

Thanks for the reply! I'll look into it.

As for the fix, I have several thoughts that might be easier:

  1. pad-specific fix: In pad, given the constructed list, is it possible to use something like mb.stack, mb.concat, or mb.gather to construct a tensor?

Yes, this should be a cleaner way for pad.

  2. general fix: The ultimate problem is, if we change _array_construct output signature, it would break backward compatibility 😞 All functions that rely on "list of symbols" rather than tensor would break

I think we both agree that a generalized fix is what we want here, because I've already come across several cases where pad is not the culprit.

Could you share some functions that rely on "list of symbols" so I can design around them? Let me see if I can cover those, or learn the current design better (because clearly I'm missing something here).

YifanShenSZ commented 10 months ago

Could you share some functions that rely on "list of symbols" so I can design around them? Let me see if I can cover those, or learn the current design better (because clearly I'm missing something here).

Unfortunately I cannot tell off the top of my head 😞 We could try to modify _array_construct, then run pytest --pyargs coremltools.converters.mil.frontend.torch to see what breaks

kdonbekci commented 5 months ago

Any progress on this issue?

xorange commented 5 months ago

Any progress on this issue?

None on my end. I'm having trouble balancing work and spare time for this, and it takes a lot to digest the whole design. I won't make any progress before 2024 Q4 at the earliest, sorry.