axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK

ADD BARK #1126

Open · kyakuno opened this issue 1 year ago

kyakuno commented 1 year ago

https://github.com/suno-ai/bark#-update (MIT license)

ooe1123 commented 9 months ago

Model modifications

For each file, the original code is shown first, followed by the modified version used for ONNX export.

○ bark/model.py

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        ...
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

class Block(nn.Module):
    ...

class GPT(nn.Module):
    ...

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        ...
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.flash = False

    ...
    def forward2(self, x, past_kv, use_cache=False):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        past_key = past_kv[[0]]
        past_value = past_kv[[1]]
        k = torch.cat((past_key, k.type(torch.float32)), dim=-2)
        v = torch.cat((past_value, v.type(torch.float32)), dim=-2)

        FULL_T = k.shape[-2]

        present = (k, v)

        if self.flash:
            if past_kv is not None:
                is_causal = False
            else:
                is_causal = True

            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
        else:
            # manual implementation of attention
            if torch.onnx.is_in_onnx_export():
                # during export tracing, k.size(-1) is recorded as a tensor, so torch.sqrt keeps the scale in the graph
                att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))
            else:
                att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return (y, present)

class Block(nn.Module):
    ...
    def forward2(self, x, past_kv, use_cache=False):
        attn_output, prev_kvs = self.attn.forward2(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
        x = x + attn_output
        x = x + self.mlp(self.ln_2(x))
        return (x, prev_kvs)

class GPT(nn.Module):
    ...
    def forward2(self, idx, past_kv):
        device = idx.device
        b, _ = idx.size()

        if 1:   # merge_context=True
        # if 0:
            tok_emb1 = self.transformer.wte(idx[:, -256-256-1:-256-1])
            tok_emb2 = self.transformer.wte(idx[:, -256-1:-1])
            tok_emb3 = self.transformer.wte(idx[:, -1:])
            tok_emb = torch.cat([
                tok_emb1+tok_emb2, tok_emb3
            ], dim=1)
        else:   # for coarse
            tok_emb = self.transformer.wte(idx)
        _, t, _ = tok_emb.shape

        past_length = past_kv[0][0].size(-2)

        position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0) # shape (1, t)
        assert position_ids.shape == (1, t)

        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

        x = self.transformer.drop(tok_emb + pos_emb)

        new_kv = ()

        for i, block in enumerate(self.transformer.h):
            x, kv = block.forward2(x, past_kv=past_kv[[i*2,i*2+1]], use_cache=True)
            new_kv = new_kv + (kv[0],kv[1])

        new_kv = torch.cat(new_kv, 0)

        x = self.transformer.ln_f(x)

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

        return (logits, new_kv)
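
For orientation (not part of the patch itself): forward2 replaces the per-block tuple cache with a single packed tensor, so the exported graph has exactly one past_kv input and one kv_cache output. Below is a minimal sketch of that packing convention, assuming the Bark text/coarse GPT configuration (24 layers, 16 heads, head size 64, batch size 1); the helper names are illustrative only.

import numpy as np

N_LAYER, N_HEAD, HEAD_DIM = 24, 16, 64  # assumed Bark text/coarse GPT configuration

def empty_cache():
    # layout used by forward2: (2 * n_layer, n_head, seq_len, head_dim)
    return np.zeros((2 * N_LAYER, N_HEAD, 0, HEAD_DIM), dtype=np.float32)

def unpack_cache(packed):
    # block i reads its key at row 2*i and its value at row 2*i + 1 (see past_kv[[i*2, i*2+1]])
    return [(packed[2 * i], packed[2 * i + 1]) for i in range(N_LAYER)]

cache = empty_cache()
print(cache.shape)               # (48, 16, 0, 64), the zeros fed on the first generation step
print(len(unpack_cache(cache)))  # 24 (key, value) pairs, one per transformer block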

○ bark/model_fine.py

class NonCausalSelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        )

    def forward(self, x):
        ...
        if self.flash:
            ...
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

class NonCausalSelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        )
        self.flash = False

    def forward(self, x):
        ...
        if self.flash:
            ...
        else:
            # manual implementation of attention
            if torch.onnx.is_in_onnx_export():
                att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))
            else:
                att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
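
Both attention classes apply the same sqrt guard. The reason, as far as I can tell, is that during ONNX export tracing k.size(-1) is recorded as a tensor, where math.sqrt would bake in a Python constant. A hedged sketch of the guard factored into a helper (attn_scale is not part of the patch):

import math
import torch

def attn_scale(k):
    """Return 1/sqrt(head_dim) in a form usable both in eager mode and during ONNX export."""
    if torch.onnx.is_in_onnx_export():
        # during export tracing, size(-1) is recorded as a tensor, so keep the sqrt in the graph
        return 1.0 / torch.sqrt(k.size(-1))
    # eager mode: size(-1) is a plain Python int
    return 1.0 / math.sqrt(k.size(-1))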
ooe1123 commented 9 months ago

[text.onnx]

○ bark/generation.py

def generate_text_semantic(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in range(n_tot_steps):
            if use_kv_caching and kv_cache is not None:
                x_input = x[:, [-1]]
            else:
                x_input = x
            logits, kv_cache = model(
                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
            )

def generate_text_semantic(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in range(n_tot_steps):
            # if 1:
            if 0:
                if use_kv_caching and kv_cache is not None:
                    x_input = x[:, [-1]]
                else:
                    x_input = x
                logits, kv_cache = model(
                    x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
                )
            else:
                if use_kv_caching and kv_cache is not None:
                    x_input = x[:, [-1]]
                else:
                    x_input = x
                    kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)

                if 1 and n > 0:
                # if 0:
                    print("------>")
                    from torch.autograd import Variable
                    model = model.cpu()
                    xx = (
                        Variable(x_input.cpu()),
                        Variable(torch.from_numpy(kv_cache).cpu()), 
                    )
                    model.forward = model.forward2
                    torch.onnx.export(
                        model, xx, 'text.onnx',
                        input_names=["x_input", "past_kv"],
                        output_names=["logits", "kv_cache"],
                        dynamic_axes={'x_input' : [1], 'past_kv' : [2]},
                        verbose=False, opset_version=14
                    )
                    print("<------")
                    1/0  # intentionally stop execution once the export is done

                logits, kv_cache = model.forward2(
                    x_input,
                    past_kv=torch.from_numpy(kv_cache).cuda(),
                )
                kv_cache = kv_cache.cpu().detach().numpy()
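
Once text.onnx is exported, each decoding step feeds the previous kv_cache output back in as past_kv. A minimal sketch of that loop, using onnxruntime purely for illustration (the ailia SDK call pattern is analogous); x stands in for the merged-context token array prepared by generate_text_semantic, and sampling is reduced to argmax here.

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("text.onnx")

x = np.zeros((1, 256 + 256 + 1), dtype=np.int64)        # placeholder for the merged context tokens
kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)  # empty cache, as in the patched generation.py

for n in range(8):  # stands in for n_tot_steps
    # first step consumes the full context, later steps only the newly sampled token
    x_input = x if n == 0 else x[:, -1:]
    logits, kv_cache = sess.run(
        ["logits", "kv_cache"],
        {"x_input": x_input, "past_kv": kv_cache},
    )
    next_token = int(np.argmax(logits[0, -1]))           # the real loop filters/samples the semantic vocab
    x = np.concatenate([x, np.array([[next_token]], dtype=np.int64)], axis=1)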
ooe1123 commented 9 months ago

[coarse.onnx]

○ bark/generation.py

def generate_coarse(
    ...
):
    ...
    with _inference_mode():
        ...
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            ...
            for _ in range(sliding_window_len):
                ...
                if use_kv_caching and kv_cache is not None:
                    x_input = x_in[:, [-1]]
                else:
                    x_input = x_in

                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)

def generate_coarse(
    ...
):
    ...
    with _inference_mode():
        ...
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            ...
            for _ in range(sliding_window_len):
                ...
                if use_kv_caching and kv_cache is not None:
                    x_input = x_in[:, [-1]]
                else:
                    x_input = x_in

                # if 1:
                if 0:
                    logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                else:
                    if use_kv_caching and kv_cache is not None:
                        x_input = x_in[:, [-1]]
                    else:
                        x_input = x_in
                        kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)

                    if 1:
                    # if 0:
                        print("------>")
                        from torch.autograd import Variable
                        model = model.cpu()
                        xx = (
                            Variable(x_input.cpu()),
                            Variable(torch.from_numpy(kv_cache).cpu()),
                        )
                        model.forward = model.forward2
                        torch.onnx.export(
                            model, xx, 'coarse.onnx',
                            input_names=["x_input", "past_kv"],
                            output_names=["logits", "kv_cache"],
                            dynamic_axes={'x_input' : [1], 'past_kv' : [2]},
                            verbose=False, opset_version=14
                        )
                        print("<------")
                        1/0

                    logits, kv_cache = model.forward2(
                        x_input,
                        past_kv=torch.from_numpy(kv_cache).cuda(),
                    )
                    kv_cache = kv_cache.cpu().detach().numpy()
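
coarse.onnx is exported with the same input names (x_input, past_kv) and output names (logits, kv_cache) as text.onnx, so one step helper can serve both models. A hedged sketch (gpt_step is not part of the patch):

import numpy as np

def gpt_step(sess, x_input, kv_cache=None):
    # kv_cache=None marks the first step: feed the empty (48, 16, 0, 64) cache,
    # matching the zeros initialised in the patched generation.py
    if kv_cache is None:
        kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)
    logits, kv_cache = sess.run(
        ["logits", "kv_cache"],
        {"x_input": np.asarray(x_input, dtype=np.int64), "past_kv": kv_cache},
    )
    return logits, kv_cache

After the first call, pass x_in[:, -1:] together with the returned kv_cache, exactly as the use_kv_caching branch above does.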
ooe1123 commented 9 months ago

[fine.onnx]

○ bark/generation.py

def generate_fine(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in tqdm.tqdm(range(n_loops), disable=silent):
            ...
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                 logits = model(nn, in_buffer)

def generate_fine(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in tqdm.tqdm(range(n_loops), disable=silent):
            ...
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                # if 0:
                if 1:
                    print("------>")
                    from torch.autograd import Variable
                    xx = (
                        Variable(torch.tensor(nn).cpu()),
                        Variable(in_buffer.cpu()),
                    )
                    model = model.cpu()
                    model.eval()
                    torch.onnx.export(
                        model, xx, 'fine.onnx',
                        input_names=["pred_idx", "idx"],
                        output_names=["logits"],
                        verbose=False, opset_version=14
                    )
                    print("<------")
                    1/0
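
fine.onnx is exported with inputs pred_idx (the codebook to predict) and idx (the current code buffer) and a single logits output; it is not autoregressive over time, so there is no kv cache. A minimal call sketch, again using onnxruntime only for illustration, with the (1, 1024, 8) buffer shape taken from Bark's fine model as an assumption:

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("fine.onnx")

in_buffer = np.zeros((1, 1024, 8), dtype=np.int64)  # (batch, window, N_FINE_CODEBOOKS), assumed layout
nn = 2                                              # codebook index to predict, as in the generate_fine loop

logits = sess.run(["logits"], {"pred_idx": np.array(nn, dtype=np.int64), "idx": in_buffer})[0]
print(logits.shape)  # one row of logits per time step for codebook nn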