lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

ONNX export failed #212

Open pfeatherstone opened 8 months ago

pfeatherstone commented 8 months ago

Here is a repro:

import torch
import torch.nn as nn
from x_transformers import TransformerWrapper, Decoder

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lm = TransformerWrapper (
            num_tokens          = 256,
            max_seq_len         = 0,
            num_memory_tokens   = 20,
            attn_layers = Decoder (
                dim             = 512,
                depth           = 1,
                heads           = 4,
                rotary_pos_emb  = True,
                shift_tokens    = 1,
                attn_flash      = True,
                attn_onnxable   = True,
                use_scalenorm   = True,
                sandwich_norm   = True
            )
        )
    def forward(self, x, mask):
        return self.lm(x, mask=mask, return_embeddings=True)

net = Model()
x = torch.randint(0, 256, size=(4, 1024))
m = x < 128
x = net(x, m)
print('Normal inference ok')

torch.onnx.export(net, (x,m), '/tmp/model.onnx', opset_version=17, 
                  input_names=['x', 'mask'],
                  output_names=['embeddings'],
                  dynamic_axes={'x'             : {0: 'B', 1: 'T'},
                                'mask'          : {0: 'B', 1: 'T'},
                                'embeddings'    : {0: 'B', 1: 'T'}})
print('Onnx export ok')

The export fails with:

b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
ValueError: too many values to unpack (expected 6)

pfeatherstone commented 8 months ago

Then if I change:

*x.shape,

to

x.shape[0], x.shape[1]

I get another error:

x_transformers.py", line 1238, in forward
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length)
return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (0) for operand 0 and no ellipsis was given
pfeatherstone commented 8 months ago

It would seem that during normal inference max_rotary_emb_length is an int, while during JIT tracing or ONNX export it is a 0-dimensional tensor.

EDIT: More generally, something like x.shape[0] is a plain int in eager PyTorch, while under tracing, scripting, or ONNX export it is a torch.Tensor. They must have changed this behaviour recently.
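
A minimal probe of the difference (just a sketch, not from the library, and the exact behaviour is PyTorch-version dependent):

import torch

def probe(x):
    n = x.shape[1]
    # Eager call: n is a plain Python int.
    # Under torch.jit.trace / torch.onnx.export (on the PyTorch version used
    # here), the same access can come back as a 0-dim torch.Tensor instead,
    # which is the behaviour observed above.
    print('x.shape[1] is', type(n))
    return x + 1

probe(torch.randn(2, 5))                      # eager run
torch.jit.trace(probe, (torch.randn(2, 5),))  # traced run (prints once, at trace time)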

pfeatherstone commented 8 months ago

Changing:

if isinstance(seq_arange_or_len, int):

at line 432 to

if isinstance(seq_arange_or_len, int) or seq_arange_or_len.dim() == 0:

seems to resolve everything, but I'm not sure; this feels like a hack.
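
A standalone version of that guard (hypothetical helper, not part of x-transformers) would look something like:

import torch

def to_plain_int(seq_arange_or_len):
    # Accept either a plain int or a 0-dim tensor, mirroring the check above.
    if isinstance(seq_arange_or_len, int):
        return seq_arange_or_len
    if torch.is_tensor(seq_arange_or_len) and seq_arange_or_len.dim() == 0:
        # Caveat: .item() during tracing records the value as a constant,
        # which can defeat dynamic axes in the exported graph.
        return int(seq_arange_or_len.item())
    return seq_arange_or_len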

lucidrains commented 8 months ago

@pfeatherstone ah yeah, I think I may have a solution

I threw in some fixes (but by no means for all configurations)

let me know if that works

lucidrains commented 8 months ago

@pfeatherstone you set max_seq_len to 0 to turn off the absolute positional embedding? That may not work as you intended (but should be fixed)
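
For reference, a sketch of turning off the absolute positional embedding via the wrapper flag instead (flag name taken from the README, so double-check against the installed version):

from x_transformers import TransformerWrapper, Decoder

lm = TransformerWrapper(
    num_tokens      = 256,
    max_seq_len     = 1024,          # a real maximum length rather than 0
    use_abs_pos_emb = False,         # disable absolute positional embedding (README flag)
    attn_layers     = Decoder(
        dim            = 512,
        depth          = 1,
        heads          = 4,
        rotary_pos_emb = True
    )
)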

pfeatherstone commented 8 months ago

I now get the following error during export:

File ".../x_transformers/x_transformers.py", line 577, in forward
    segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
  File ".../x_transformers/x_transformers.py", line 577, in <lambda>
    segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
  File ".../x_transformers/x_transformers.py", line 560, in shift
    t = t.masked_fill(~mask[..., None], 0.)
RuntimeError: The size of tensor a (1044) must match the size of tensor b (524308) at non-singleton dimension 1
pfeatherstone commented 8 months ago

@lucidrains Sorry to bother you again, but it would be really cool to get this working with ONNX. At some point I might submit a PR that adds CI/CD; some unit tests would go a long way.

lucidrains commented 8 months ago

@pfeatherstone can you try it without shift tokens?
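
i.e. the repro's Decoder config with the shift_tokens line dropped, roughly:

attn_layers = Decoder(
    dim            = 512,
    depth          = 1,
    heads          = 4,
    rotary_pos_emb = True,
    attn_flash     = True,
    attn_onnxable  = True,
    use_scalenorm  = True,
    sandwich_norm  = True
)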

lucidrains commented 8 months ago

@pfeatherstone yeah, I know some others have already gotten ONNX to work in production, so it definitely works for some configurations, just not all. The repository at this point prioritizes simplicity; it is not worth bending over backwards to make ONNX work for all settings.

pfeatherstone commented 8 months ago

@lucidrains No, it didn't work either

lucidrains commented 8 months ago

ah alright, I'll have to circle back to this some other time

pfeatherstone commented 8 months ago

> The repository at this point prioritizes simplicity; it is not worth bending over backwards to make ONNX work for all settings.

OK, cool. To be honest, once I've nailed down the configurations I want, I might rewrite it from scratch, keeping exactly what I need; then it will probably be easier to debug the ONNX export.

lucidrains commented 8 months ago

@pfeatherstone yes exactly, that is how I intended it to be

pfeatherstone commented 8 months ago

I've tried so many configurations, and it turns out I only really need: