Datamance opened this issue 1 year ago
OK, looking at the graph output, it looks like `index_put` is first called in this method, which is called by `forward` in `LongformerSelfAttention`:
```python
def _concat_with_global_key_attn_probs(
    self,
    key_vectors,
    query_vectors,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
):
    batch_size = key_vectors.shape[0]

    # create only global key vectors
    key_vectors_only_global = key_vectors.new_zeros(
        batch_size, max_num_global_attn_indices, self.num_heads, self.head_dim
    )
    key_vectors_only_global[is_local_index_global_attn_nonzero] = key_vectors[is_index_global_attn_nonzero]
```
That call happens inside a conditional block in `forward` that checks whether the pass uses global attention:
```python
# compute local attention probs from global attention keys and contact over window dim
if is_global_attn:
    # compute global attn indices required through out forward fn
    (
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ) = self._get_global_attn_indices(is_index_global_attn)
    # calculate global attn probs from global key
    global_key_attn_scores = self._concat_with_global_key_attn_probs(
        query_vectors=query_vectors,
        key_vectors=key_vectors,
        max_num_global_attn_indices=max_num_global_attn_indices,
        is_index_global_attn_nonzero=is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
    )
```
So somehow, it looks like `key_vectors` is the wrong shape. Or, more precisely, `key_vectors[is_index_global_attn_nonzero]` is short one dimension. I also ruled out batch size being the complicating factor by changing all the input shapes to (4, 4096), which is closer to what I'd do in real training.
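For reference, here's a quick eager-mode sanity check with toy sizes (my own throwaway example, not the real model) of what that indexing should give:

```python
import torch

# Toy stand-ins for the Longformer tensors: (batch, seq_len, num_heads, head_dim)
batch_size, seq_len, num_heads, head_dim = 2, 8, 12, 64
key_vectors = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Pretend the first 3 tokens of every sequence get global attention.
is_index_global_attn = torch.zeros(batch_size, seq_len, dtype=torch.bool)
is_index_global_attn[:, :3] = True
is_index_global_attn_nonzero = is_index_global_attn.nonzero(as_tuple=True)

# Advanced indexing with the tuple of index tensors keeps the trailing dims:
print(key_vectors[is_index_global_attn_nonzero].shape)  # torch.Size([6, 12, 64])
```

The left-hand-side assignment into `key_vectors_only_global[is_local_index_global_attn_nonzero]` is presumably what gets lowered to `index_put` (and then `scatter_nd`) in the traced graph.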
I will have to dig in more to see why the mismatch happens, but again I am all ears for different approaches/solutions to this issue.
Hey @Datamance, the error you got:
File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS17/scatter_gather.py", line 228, in type_inference
result = super().type_inference()
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/ricorodriguez/GradSchool/Classes/Fall2023/MachineLearning/SecondProject/env/lib/python3.11/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py", line 546, in type_inference
assert is_compatible_symbolic_vector(
AssertionError
is probably due to some dynamic shape involved in your model. 3 observations supporting my guess:

1. The failing assertion, `is_compatible_symbolic_vector`, is a compatibility check on symbolic (dynamic) shapes.
2. `new_ones` creates a new tensor full of 1s, given the shape of another tensor. It is possible that this "another tensor" has a dynamic shape.
3. Your debugger output `Out[5]: (is5, 12, 64)` contains `is5`, which is the "5th integer symbol", i.e. a symbolic dimension.

I would suggest trying a fixed-shape model first, to rule out symbol noise and make sure things convert.
Here is my script:
I think tracing is the right call here because my input will always be collated and padded to the correct shape. So I'm not worried about any of the warnings (which mostly come from assert statements anyway, so they don't really affect control flow).
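The core of it boils down to something like this (a simplified sketch, not the full script; the checkpoint name and the positional argument order for `LongformerModel.forward` are from memory):

```python
import torch
from transformers import LongformerModel, LongformerTokenizerFast

# torchscript=True makes the model return tuples instead of dict outputs,
# which torch.jit.trace can handle.
model = LongformerModel.from_pretrained("allenai/longformer-base-4096", torchscript=True)
model.eval()

# The "lie to the config" bit mentioned below, so the self-attention takes the
# ONNX-friendly path instead of as_strided (attribute name as I recall it).
model.config.onnx_export = True

tokenizer = LongformerTokenizerFast.from_pretrained("allenai/longformer-base-4096")
enc = tokenizer("hello world", padding="max_length", max_length=4096, return_tensors="pt")

# Give at least one token global attention so the global-attention branch is traced.
global_attention_mask = torch.zeros_like(enc["input_ids"])
global_attention_mask[:, 0] = 1

traced_model = torch.jit.trace(
    model,
    (enc["input_ids"], enc["attention_mask"], global_attention_mask),
)
```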
You can see above that I circumvented the initial errors about the `as_strided` op not being implemented by lying to the config about exporting to ONNX. That got me a little farther, but then I hit another wall at `new_ones`. I tried to circumvent this yet again by registering another op, which is basically a slightly modified copy of `new_full` (see the sketch right after this paragraph). That seems to work (although I'm curious to get feedback on whether this is the right approach?), but then I get the AssertionError from `scatter_gather.py`'s type inference (the traceback quoted above).
Stepping through with a debugger, it seems like the `index_put` and `scatter_nd` ops are the culprits here. When the builder tries to add `scatter_nd`, it does some kind of type-inference check and fails the `is_compatible_symbolic_vector` assertion in `scatter_gather.py`. Checking the shapes in the debugger shows `Out[5]: (is5, 12, 64)`.
Immediate question: did I do something wrong, or is this a bug?
Follow-up: what is the recommendation to proceed here? For context, I would ideally like to do training-time quantization (probably Linear 8-bit) on this Longformer model. Will I keep running into unimplemented ops even if I get past this issue? Should I try something else entirely?
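For concreteness, the quantization step I have in mind is roughly this (based on my reading of the `coremltools.optimize.torch` docs; treat the config keys and method names here as my assumptions):

```python
from coremltools.optimize.torch.quantization import LinearQuantizer, LinearQuantizerConfig

# Symmetric linear 8-bit quantization-aware training config (keys assumed from the docs).
config = LinearQuantizerConfig.from_dict({"global_config": {"quantization_scheme": "symmetric"}})
quantizer = LinearQuantizer(model, config)

# prepare() inserts fake-quantization modules; example_inputs is needed for graph capture.
prepared = quantizer.prepare(example_inputs=(input_ids, attention_mask, global_attention_mask))

# ...fine-tune `prepared` as usual, calling quantizer.step() after each training step...

quantized_model = quantizer.finalize()
```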
Thanks!