Seems to work fine on A100, but not H100.
Ah nice, yeah, that seems like an einops / PyTorch-specific error, but not entirely sure
What is your use case, btw? That's a really interesting network
Oh, are you doing a two-tower architecture?
Thanks! This is for a model that requires encoders for two different modalities. Btw, would you expect any significant speedup from torch.compile if flash attention is enabled?
Hi @lucidrains,
I am trying to use torch.compile() with a model that wraps two x-transformers Encoders. When I run the following minimal example:
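A minimal sketch of the setup, assuming dim = 128 and heads = 2 (which match the traceback below); `TwoEncoderModel`, the depth, and the sequence lengths are illustrative:

```python
import torch
from torch import nn
from x_transformers import Encoder

class TwoEncoderModel(nn.Module):
    # wraps two independent x-transformers Encoders, one per input
    def __init__(self, dim = 128, depth = 2, heads = 2):
        super().__init__()
        self.encoder_1 = Encoder(dim = dim, depth = depth, heads = heads)
        self.encoder_2 = Encoder(dim = dim, depth = depth, heads = heads)

    def forward(self, x1, x2):
        return self.encoder_1(x1), self.encoder_2(x2)

seq_len_1 = 512
seq_len_2 = 256  # differs from seq_len_1, so the length ends up traced as dynamic

model = TwoEncoderModel().cuda()
compiled = torch.compile(model)

x1 = torch.randn(1, seq_len_1, 128).cuda()
x2 = torch.randn(1, seq_len_2, 128).cuda()

out1, out2 = compiled(x1, x2)
```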
I get the following error:

```
TorchRuntimeError: Failed running call_function (*(FakeTensor(..., device='cuda:0', size=(1, s0, 128), grad_fn=), 'b n (h d) -> b h n d'), **{'h': 2}): unhashable type: non-singleton SymInt
```
Which comes from: https://github.com/lucidrains/x-transformers/blob/2a0ec67fbdad18d2bd5f8bf3d9bc20e705a58a6b/x_transformers/x_transformers.py#L801

Surprisingly, the model compiles successfully if I set `seq_len_2 = seq_len_1`, but I don't know why.

I am using einops 0.7.0rc1 and pytorch 2.1.0.
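One guess (unverified): einops' rearrange caches on the tensor's shape, and when the two lengths match, dynamo keeps the sequence length as a plain int, whereas differing lengths promote it to a SymInt, which is not hashable. If that's right, forcing static shapes might sidestep it, e.g.:

```python
# possible workaround sketch, not confirmed in this thread:
# disable dynamic-shape tracing so sequence lengths stay concrete ints
compiled = torch.compile(model, dynamic=False)
```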
Thanks!