Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Dynamic shapes: use t.size(d) output as input to .view/.reshape #939
LitGPT's `CausalSelfAttention.forward` contains the following code, which uses tensor sizes (in particular, the batch size and sequence length of the inputs). It would be super useful if dynamic shape support let us avoid re-acquiring the Python trace for varying batch sizes. This would require `.size()` to return a tuple of `NumberProxy` objects, possibly support for very simple computation on them (though not needed here, I think), and the ability to feed them into `.view` and `.reshape`.
```python
B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

qkv = self.attn(x)

# assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
q_per_kv = self.config.n_head // self.config.n_query_groups
total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

# split batched computation into three
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
...
q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
...
y = y.reshape(B, T, self.config.head_size * self.config.n_head)  # re-assemble all head outputs side by side

# output projection
return self.proj(y)
```
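To show why the sizes flowing into `.view`/`.reshape` are shape-dependent, here is a standalone sketch of the same qkv reshaping pattern in plain PyTorch, run at two different batch sizes and sequence lengths. The config values (`n_head=8`, `n_query_groups=2`, `head_size=16`) are made up for illustration and are not taken from LitGPT; the point is that `B` and `T` come from `x.size()`, so a compiler that bakes them in as constants must re-trace for every new input shape.

```python
import torch

# Hypothetical config values, chosen only for this sketch
n_head, n_query_groups, head_size = 8, 2, 16
q_per_kv = n_head // n_query_groups  # queries per key/value group
total_qkv = q_per_kv + 2             # q_per_kv queries + 1 key + 1 value per group

# The same trace should ideally serve all of these input shapes
for B, T in [(2, 16), (4, 32)]:
    # stand-in for the fused qkv projection output
    x = torch.randn(B, T, (n_head + 2 * n_query_groups) * head_size)
    qkv = x.view(B, T, n_query_groups, total_qkv, head_size)
    qkv = qkv.permute(0, 2, 3, 1, 4)              # (B, n_query_groups, total_qkv, T, hs)
    q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)  # split along the total_qkv axis
    q = q.reshape(B, -1, T, head_size)            # (B, nh_q, T, hs)
    k = k.reshape(B, -1, T, head_size)            # (B, nh_k, T, hs)
    v = v.reshape(B, -1, T, head_size)            # (B, nh_v, T, hs)
    assert q.shape == (B, n_head, T, head_size)
    assert k.shape == (B, n_query_groups, T, head_size)
    assert v.shape == (B, n_query_groups, T, head_size)
```

With static shapes, each `(B, T)` pair above would trigger a fresh trace; with `.size()` returning `NumberProxy` values that `.view`/`.reshape` accept, a single trace could cover both.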
@jjsjann123