pfeatherstone opened this issue 8 months ago
Then if I change `*x.shape,` to `x.shape[0], x.shape[1]`, I get another error:
x_transformers.py", line 1238, in forward
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
return _VF.einsum(equation, operands) # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (0) for operand 0 and no ellipsis was given
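For what it's worth, this failure mode reproduces in isolation: if the rotary embedding mistakes the 0-dim length for an already-built `arange` and feeds it straight into the einsum, you get exactly this error. A minimal sketch of my guess at the mechanism, not the library's actual call:

```python
import torch

# a 0-dim tensor, as tracing / ONNX export apparently produces for shape values
seq_len = torch.tensor(5)
inv_freq = torch.randn(8)

# subscript 'i' expects a 1-dim operand, but seq_len has 0 dims
torch.einsum('i,j->ij', seq_len, inv_freq)
# RuntimeError: einsum(): the number of subscripts in the equation (1)
# does not match the number of dimensions (0) for operand 0 ...
```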
It would seem that during normal inference `max_rotary_emb_length` is an `int`, while during JIT tracing or ONNX export it's a 0-dimensional tensor.

EDIT: It looks like, in general, something like `x.shape[0]` is a plain `int` in eager PyTorch, while during tracing, scripting or ONNX export it's a `torch.Tensor`. They must have changed the behaviour recently.
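A quick way to see this (the exact types may vary across torch versions; the annotations reflect what I'd expect):

```python
import torch

def f(x):
    n = x.shape[0]
    print(type(n))  # eager: <class 'int'>; under trace: <class 'torch.Tensor'>
    return x + 0

f(torch.randn(4, 8))                      # prints <class 'int'>
torch.jit.trace(f, (torch.randn(4, 8),))  # prints <class 'torch.Tensor'>
```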
Changing:

```python
if isinstance(seq_arange_or_len, int):
```

at line 432 to

```python
if isinstance(seq_arange_or_len, int) or seq_arange_or_len.dim() == 0:
```

seems to resolve everything, but I dunno, this seems like a hack.
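If it helps, the same check could be pulled into a helper so every call site treats a plain int and a traced 0-dim tensor uniformly; a hypothetical sketch, not the patch that actually landed:

```python
import torch

def is_int_like(n):
    # a plain Python int, or the 0-dim tensor that tracing / ONNX export
    # substitutes for shape values
    return isinstance(n, int) or (torch.is_tensor(n) and n.dim() == 0)
```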
@pfeatherstone ah yea, think i may have a solution
threw in some fixes (but by no means for all configurations)
let me know if that works
@pfeatherstone you set `max_seq_len` to 0 to turn off absolute positional embedding? may not work as you intended (but should be fixed)
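for reference, turning off absolute positional embeddings is meant to go through `use_abs_pos_emb`; a sketch with placeholder values:

```python
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,         # placeholder values throughout
    max_seq_len = 1024,       # keep this > 0
    use_abs_pos_emb = False,  # the intended switch for absolute pos emb
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
```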
I now get the following error during export:

```
File ".../x_transformers/x_transformers.py", line 577, in forward
    segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
File ".../x_transformers/x_transformers.py", line 577, in <lambda>
    segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
File ".../x_transformers/x_transformers.py", line 560, in shift
    t = t.masked_fill(~mask[..., None], 0.)
RuntimeError: The size of tensor a (1044) must match the size of tensor b (524308) at non-singleton dimension 1
```
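For context on the error class: `masked_fill` broadcasts the mask against the tensor and raises exactly this when a non-singleton dimension disagrees. Illustrative shapes below, not the ones from the repro:

```python
import torch

b, n, d = 2, 5, 8
t = torch.randn(b, n, d)
mask = torch.ones(b, n, dtype = torch.bool)

t.masked_fill(~mask[..., None], 0.)      # fine: (b, n, 1) broadcasts over d

bad = torch.ones(b, n + 3, dtype = torch.bool)  # sequence dims disagree
# t.masked_fill(~bad[..., None], 0.)     # RuntimeError: The size of tensor a (5)
#                                        # must match the size of tensor b (8)
#                                        # at non-singleton dimension 1
```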
@lucidrains sorry to bother again, but it would be really cool to get this working with ONNX. At some point I might submit a PR that adds CI/CD; some unit tests would go a long way.
@pfeatherstone can you try it without shift tokens?
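i.e. drop the `shift_tokens` kwarg from the attention layers, something like this (placeholder values):

```python
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
        # shift_tokens = 1  # <- leave this out to test without shift tokens
    )
)
```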
@pfeatherstone yea, i know some others have already gotten onnx to work in production, so it definitely works for some configurations, just not all. the repository at this point prioritizes simplicity; it is not worth bending over backwards to make onnx work for all settings.
@lucidrains No it didn't work either
ah alright, i'll have to circle back to this some other time
OK cool. To be honest, once I've nailed down the configurations I want, I might rewrite it from scratch, keeping exactly what I need; then it will probably be easier to debug the ONNX export.
@pfeatherstone yes exactly, that is how i intended it to be
I've tried so many configurations, and it turns out I only really need:
Here is a repro:
The export fails with the message:

```
b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
ValueError: too many values to unpack (expected 6)
```
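In eager PyTorch, one way to trigger this exact ValueError is an input with the wrong number of dimensions, since `*x.shape` contributes one value per dim; whether that is what happens inside the export trace is a guess on my part. Illustration:

```python
import torch

x = torch.zeros(2, 10, 512)  # 3-D input: *x.shape now yields three values

# six targets on the left, but three shape values plus four trailing values
# on the right -> ValueError: too many values to unpack (expected 6)
b, n, device, num_mems, has_mem, frac = *x.shape, x.device, 0, False, 1.0
```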