**YifanShenSZ** opened 1 year ago
In the proposed fix to issue #1303, we are trying to gather symbols:

```python
else:
    # Create the new_shape and input_shape to support dynamic sizes.
    xshape = mb.shape(x=x)
    for i, v in enumerate(x.shape[2:]):
        if is_symbolic(v):
            si = mb.gather(x=xshape, indices=i + 2, axis=0)
            new_shape.append(si)
            input_shape[i + 2] = si
        else:
            new_shape.append(v)
```
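The same control flow can be sketched in plain Python without coremltools; here `is_symbolic`, `mb.shape`, and `mb.gather` are replaced by illustrative stand-ins (a string dim plays the role of a sympy symbol):

```python
def is_symbolic(v):
    # Stand-in for coremltools' is_symbolic: treat any non-int dim as symbolic.
    return not isinstance(v, int)

def build_new_shape(x_shape):
    # Keep the leading dims; copy static dims as ints, and for symbolic dims
    # record a ("gather", index) marker standing in for
    # mb.gather(x=mb.shape(x=x), indices=i + 2, axis=0).
    new_shape = list(x_shape[:2])
    for i, v in enumerate(x_shape[2:]):
        if is_symbolic(v):
            new_shape.append(("gather", i + 2))
        else:
            new_shape.append(v)
    return new_shape

print(build_new_shape([1, 192, "is0"]))  # [1, 192, ('gather', 2)]
```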
In reproducing issue #1921, `torch.listconstruct` is found to be the culprit:
```
(Pdb) print(context.torch_graph)
graph(
    %x : Tensor(1, 192, RangeDim(lower_bound=2, upper_bound=1024, default=10, symbol="is0"), 'None'),
):
    %2 = constant[value=2]()
    %3 = size[](%x, %2)
    %pad_length = numtotensor[](%3)
    %5 = int[](%pad_length)
    %6 = int[](%pad_length)
    %7 = constant[value=0]()
    %8 = constant[value=0]()
    %9 = listconstruct[](%7, %8, %6, %5)
    %10 = constant[value=constant]()
    %11 = constant[]()
    %12 = pad[](%x, %9, %10, %11)
    return (%12)
(Pdb) print(node)
%12 = pad[](%x, %9, %10, %11)
```
The `pad` input, i.e. variable `%9`, is the output of `torch.listconstruct`, and consists of 2 consts and 2 symbols.
Yes, for the #1921 case (for op `pad`):

```python
# pseudo code
@register_torch_op
def listconstruct():
    if constant shape:  # no-bug case
        return mb.const(val=[static shapes, ...])
    else:  # the #1921 case
        # a mix of static and dynamic shapes, which fails to parse in op 'pad'
        return [mixed static and dynamic shapes, ...]
```
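As a toy reproduction of that branch (plain Python, all names hypothetical, no coremltools required): when every element is a known int the list can fold to a constant, but with symbolic elements a raw mixed list comes back, which a consumer expecting a const cannot parse:

```python
class SymbolicVar:
    # Stand-in for a MIL Var whose value is unknown at conversion time.
    def __init__(self, name):
        self.name = name
        self.val = None

def toy_listconstruct(elements):
    if all(isinstance(e, int) for e in elements):
        return ("const", elements)  # the no-bug path: fold to a constant
    return elements                 # the #1921 path: raw mixed list

def toy_pad_consumer(pad):
    # A consumer that only understands the folded-constant form.
    if isinstance(pad, tuple) and pad[0] == "const":
        return pad[1]
    raise TypeError("cannot parse mixed static/symbolic pad list")

print(toy_listconstruct([0, 0, 0, 0]))  # ('const', [0, 0, 0, 0])
try:
    toy_pad_consumer(toy_listconstruct([0, 0, SymbolicVar("is0"), SymbolicVar("is0")]))
except TypeError as e:
    print(e)  # cannot parse mixed static/symbolic pad list
```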
If we want to fix it following the solution in #1303 (#1922), we basically want:

```python
# pseudo code
def listconstruct():
    if constant shape:  # no-bug case
        return mb.const(val=[static shapes, ...])
    elif match_#1921_case():
        static_shapes = [...]  # extract static shapes from inputs
        # extract dynamic shape symbols from inputs
        sliced_dynamic_shapes = mb.slice_by_size(x=, begin=[], size=)
        # mb.concat() the static and dynamic parts of the shape
        return mb.concat(values=[static_shapes, sliced_dynamic_shapes])
```
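The split-and-concat idea can be sanity-checked with lists standing in for MIL tensors (`mb.shape` → `list`, `mb.slice_by_size` → slicing, `mb.concat` → list concatenation; names are illustrative, not the real fix):

```python
def sketch_fixed_listconstruct(static_part, x_shape, begin, size, repeats=1):
    # "slice_by_size": pull the symbolic dims out of the runtime shape...
    dynamic_part = list(x_shape)[begin:begin + size]
    # ..."concat": stitch the static and dynamic pieces into one shape list.
    return static_part + dynamic_part * repeats

# For the #1921 pad list [0, 0, is0, is0] built from x of shape (1, 192, is0):
pad_list = sketch_fixed_listconstruct([0, 0], [1, 192, "is0"], begin=2, size=1, repeats=2)
print(pad_list)  # [0, 0, 'is0', 'is0']
```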
To be specific, in the #1921 case:

```
>>> context.torch_graph
graph(
    %x : Tensor(1, 192, RangeDim(lower_bound=2, upper_bound=1024, default=10, symbol="is0"), 'None'),
):
    %2 = constant[value=2]()
    %3 = size[](%x, %2)
    %pad_length = numtotensor[](%3)
    %5 = int[](%pad_length)
    %6 = int[](%pad_length)
    %7 = constant[value=0]()
    %8 = constant[value=0]()
    %9 = listconstruct[](%7, %8, %6, %5)
    %10 = constant[value=constant]()
    %11 = constant[]()
    %12 = pad[](%x, %9, %10, %11)
    return (%12)
>>> node.name  # I pdb in def _array_construct()
'9'
>>> node.inputs
['7', '8', '6', '5']
>>> context['7'].val  # static shapes
0
>>> context['8'].val  # static shapes
0
>>> context['6'].op  # dynamic shapes
%6: (int32)(Scalar) = cast(x=%gather_0, dtype="int32", name="6")
>>> context['5'].op  # dynamic shapes
%5: (int32)(Scalar) = cast(x=%gather_0, dtype="int32", name="5")
```
If we hard-coded the solution for it:

```python
def _array_construct(context, node, array_type):
    ...
    else:
        # context.add(array_type(inputs), node.name)
        static_7_8 = [context['7'].val, context['8'].val]
        sliced_dynamic_6 = mb.slice_by_size(x=mb.shape(x=context['x']), begin=[2], size=[1])
        sliced_dynamic_5 = mb.slice_by_size(x=mb.shape(x=context['x']), begin=[2], size=[1])
        context.add(mb.concat(values=[static_7_8, sliced_dynamic_6, sliced_dynamic_5], axis=0), node.name)
```
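One detail worth noting in the hard-coded version: `mb.slice_by_size` is used rather than `mb.gather` because a slice keeps a length-1 tensor that can be concatenated along axis 0, whereas a gather yields a scalar. A list-based sketch (names illustrative):

```python
shape = [1, 192, "is0"]   # stand-in for mb.shape(x=context['x'])

gathered = shape[2]       # a gather would yield a scalar: 'is0'
sliced = shape[2:3]       # slice_by_size(begin=[2], size=[1]) keeps a length-1 list

static_7_8 = [0, 0]       # context['7'].val, context['8'].val
pad_9 = static_7_8 + sliced + sliced  # concat along axis 0
print(pad_9)  # [0, 0, 'is0', 'is0']
```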
and modify the `pad` op registration so that it supports padding with symbolic values:

```python
@register_torch_op(torch_alias=['constant_pad_nd'])
def pad(context, node):
    ...
    if pad.val is not None:
        ...
    else:
        missing_dims = (x.rank * 2 - pad.shape[0]) // 2
        pad = mb.concat(values=[pad, [0, 0] * missing_dims], axis=0)
        pad = mb.reverse(x=pad, axes=[0])
```
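The length arithmetic in that dynamic branch can be checked in plain Python (list operations stand in for `mb.concat` and `mb.reverse`; this only illustrates the bookkeeping, not the exact MIL `pad` semantics):

```python
def expand_and_reverse_pad(pad, rank):
    # torch pads the last dims first and gives (before, after) pairs, so a
    # rank-r input needs 2 * r entries once untouched leading dims get zeros.
    missing_dims = (rank * 2 - len(pad)) // 2
    full = pad + [0, 0] * missing_dims  # the mb.concat step
    return full[::-1]                   # the mb.reverse step

# Rank-3 input, torch pad = (left=1, right=2) on the last dim only:
print(expand_and_reverse_pad([1, 2], 3))  # [0, 0, 0, 0, 2, 1]
```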
However, I have some doubts before proposing a general fix: in the case above, the padding value is hard-coded. My question:
How can we relate from node `'9'.inputs` all the way back to the symbolic value of input `x`, given the data structures `context`, `node`, and `<coremltools.converters.mil.mil.var.Var object>`? (`context['9'].inputs[2]`, for example)

Edit: I've noticed `context.torch_graph.nodes`, and it answers my question above. WIP...
Hi @xorange, thanks for looking into this issue! About relating a symbol to input symbols, you should be able to simply compare whether those symbols are the same: we propagate symbols using sympy.
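The symbol-comparison idea might be sketched like this (plain strings stand in for the sympy symbols coremltools actually propagates, so the snippet runs standalone; names are illustrative):

```python
net_input_shape = [1, 192, "is0"]       # symbolic third dim of %x
constructed_pad = [0, 0, "is0", "is0"]  # what listconstruct produced for %9

def relate_to_input(sym, input_shape):
    # Return the input axis carrying the same symbol, or None if unrelated.
    for axis, dim in enumerate(input_shape):
        if dim == sym:
            return axis
    return None

print(relate_to_input(constructed_pad[2], net_input_shape))  # 2
```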
As for the fix, I have several thoughts that might be easier:

- `pad`-specific fix: In `pad`, given the constructed list, is it possible to use something like `mb.stack`, `mb.concat`, or `mb.gather` to construct a tensor?
- general fix: The ultimate problem is, if we change the `_array_construct` output signature, it would break backward compatibility 😞 All functions that rely on a "list of symbols" rather than a tensor would break.

Quoting from #2050:
- This fix adds another branch, only targets the op `gather`, and what it gathers from is not a name in `context.torch_graph.nodes`, i.e. the net inputs.

Upon #2037, and another net structure on my hands that shares a similar root cause, it is clear that only targeting op `gather` is not enough to provide a generalized fix.
For example,

```
>>> context.torch_graph
graph(
    %x.1 : Tensor(1, 3, RangeDim(lower_bound=300, upper_bound=400, default=300, symbol="is0"), RangeDim(lower_bound=300, upper_bound=400, default=300, symbol="is1"), 'None'),
    ...
)
%input.1 = _convolution[](%x.1, %model.features.conv1.0.weight, %8, %31, %32, %33, %12, %34, %11, %12, %12, %14, %14)
```

or

```
>>> context.torch_graph
graph(
    %x.1 : Tensor(1, RangeDim(lower_bound=5, upper_bound=512, default=275, symbol="is0"), 'None'),
    %x_mask : Tensor(1, 1, RangeDim(lower_bound=5, upper_bound=512, default=275, symbol="is1"), 'None'),
    ...
)
%9 = embedding[](%emb.weight, %x.1, %7, %6, %6)
%x.3 = mul[](%9, %10)
%x.5 = transpose[](%x.3, %12, %13)
%x.7 = mul[](%x.5, %x_mask)
%input.1 = mul[](%x.7, %x_mask)
```
> Hi @xorange, thanks for looking into this issue! About relating a symbol to input symbols, you should be able to simply compare if those symbols are the same: we propagate symbols using sympy

Thanks for the reply! I'll look into it.
> As of the fix, I have several thoughts that might be easier:
>
> - `pad`-specific fix: In `pad`, given the constructed list, is it possible to use something like `mb.stack`, `mb.concat`, or `mb.gather` to construct a tensor?

Yes, this should be a cleaner way for `pad`.

> - general fix: The ultimate problem is, if we change the `_array_construct` output signature, it would break backward compatibility 😞 All functions that rely on "list of symbols" rather than tensor would break
I think we both agree that a generalized fix is what we want here, because I've already come across several cases where `pad` is not the culprit.
Could you share some functions that rely on "list of symbols" for me to design around? Let me see if I can cover those, or learn the current design better (because clearly I'm missing something here).

> Could you share some functions that rely on "list of symbols" for me to design around? Let me see if I can cover those, or learn the current design better (because clearly I'm missing something here).
Unfortunately I cannot tell off the top of my head 😞 We could try modifying `_array_construct`, then run `pytest --pyargs coremltools.converters.mil.frontend.torch` to see what breaks.
Any progress on this issue?
> Any progress on this issue?
None on my end. I'm having trouble coordinating between work and spare time for this, and it takes a lot to digest the whole design. No progress will come from me before 2024 Q4 at the earliest, sorry.
This is the root cause of many issues: when a symbolic shape is involved in `torch.listconstruct`, instead of a Core ML tensor we simply return the list as is.

**Appendix 1: Issues Sharing the Same Root Cause**

**Appendix 2: Ops Impacted by the Root Cause**

- `torch.GroupNorm`
- `torch.pad`
- `torch.index_put`