mistralai / mistral-inference

Official inference library for Mistral models
https://mistral.ai/
Apache License 2.0

How to process batch input in mistral-src/model.py? #78

Open. NLPwoods opened this issue 10 months ago.

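```python
# x is 2-D, (seqlen_sum, dim): there is no separate batch axis here
seqlen_sum, _ = x.shape

# Project to queries/keys/values, then split the last dim into heads
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(seqlen_sum, self.n_heads, self.args.head_dim)     # (seqlen_sum, n_heads, head_dim)
xk = xk.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)  # (seqlen_sum, n_kv_heads, head_dim)
xv = xv.view(seqlen_sum, self.n_kv_heads, self.args.head_dim)  # (seqlen_sum, n_kv_heads, head_dim)
```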
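The name `seqlen_sum` suggests that `x` is not a padded `(batch, seqlen, dim)` tensor. Instead, it may hold all sequences of the batch concatenated along the token axis. Below is a minimal sketch of that idea in plain PyTorch, assuming this packed layout. `pack_batch`, `block_causal_mask` and `packed_attention` are hypothetical helpers, the sizes are made up, and the explicit boolean mask is a stand-in for whatever masking the library itself uses.

```python
import torch

# Hypothetical model sizes, chosen only for illustration.
DIM, N_HEADS, N_KV_HEADS = 16, 4, 2
HEAD_DIM = DIM // N_HEADS

torch.manual_seed(0)
# Stand-ins for self.wq / self.wk / self.wv (bias-free linear projections).
wq = torch.nn.Linear(DIM, N_HEADS * HEAD_DIM, bias=False)
wk = torch.nn.Linear(DIM, N_KV_HEADS * HEAD_DIM, bias=False)
wv = torch.nn.Linear(DIM, N_KV_HEADS * HEAD_DIM, bias=False)


def pack_batch(seqs: list[torch.Tensor]) -> tuple[torch.Tensor, list[int]]:
    """Concatenate (len_i, DIM) sequences into one (seqlen_sum, DIM) tensor."""
    return torch.cat(seqs, dim=0), [s.shape[0] for s in seqs]


def block_causal_mask(seqlens: list[int]) -> torch.Tensor:
    """Boolean (seqlen_sum, seqlen_sum) mask: query i may attend to key j only
    if both tokens come from the same sequence and j <= i."""
    seq_ids = torch.repeat_interleave(torch.arange(len(seqlens)), torch.tensor(seqlens))
    pos = torch.arange(seq_ids.numel())
    same_seq = seq_ids[:, None] == seq_ids[None, :]
    causal = pos[None, :] <= pos[:, None]
    return same_seq & causal


def packed_attention(x: torch.Tensor, seqlens: list[int]) -> torch.Tensor:
    seqlen_sum, _ = x.shape  # same unpacking as in the snippet above
    xq = wq(x).view(seqlen_sum, N_HEADS, HEAD_DIM)
    xk = wk(x).view(seqlen_sum, N_KV_HEADS, HEAD_DIM)
    xv = wv(x).view(seqlen_sum, N_KV_HEADS, HEAD_DIM)
    # Grouped-query attention: repeat each KV head for its group of query heads.
    rep = N_HEADS // N_KV_HEADS
    xk = xk.repeat_interleave(rep, dim=1)
    xv = xv.repeat_interleave(rep, dim=1)
    # Scores per head: (n_heads, seqlen_sum, seqlen_sum); mask out cross-sequence
    # and future positions before the softmax.
    scores = torch.einsum("qhd,khd->hqk", xq, xk) / HEAD_DIM**0.5
    scores = scores.masked_fill(~block_causal_mask(seqlens), float("-inf"))
    out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), xv)
    return out.reshape(seqlen_sum, N_HEADS * HEAD_DIM)


seqs = [torch.randn(3, DIM), torch.randn(5, DIM)]
x, seqlens = pack_batch(seqs)
out = packed_attention(x, seqlens)
print(out.shape)  # torch.Size([8, 16])
```

The mask is block-diagonal, so each sequence only attends to its own earlier tokens. As a result, the packed output should match running each sequence through `packed_attention` on its own. No pad tokens are needed. However, this dense-mask version still computes the full `seqlen_sum` by `seqlen_sum` score matrix, and memory-efficient attention kernels can avoid that.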