Open NLPwoods opened 10 months ago
# Fragment of a method body: project the input into query/key/value and
# reshape each projection into per-head form.

# x.shape is unpacked into exactly two values, so x must be 2-D here; only
# the leading dimension is kept.
# NOTE(review): the name suggests dim 0 is the summed length of several
# sequences packed together -- confirm against the caller.
seqlen_sum, _ = x.shape

# The same input feeds all three projections.
# NOTE(review): wq/wk/wv are presumably linear layers -- confirm in __init__.
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

# Reshape to (seqlen_sum, heads, head_dim). Queries use n_heads while keys
# and values use n_kv_heads, so the two head counts are allowed to differ.
xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim)
xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)
xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)