Open kyakuno opened 1 year ago
Model modifications
○ bark/model.py
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

class Block(nn.Module):
    ...

class GPT(nn.Module):
    ...
↓
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.flash = False  # force the manual attention path for the ONNX export

    ...

    def forward2(self, x, past_kv, use_cache=False):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # slice this block's cached key/value out of the flat past_kv tensor
        past_key = past_kv[[0]]
        past_value = past_kv[[1]]
        k = torch.cat((past_key, k.type(torch.float32)), dim=-2)
        v = torch.cat((past_value, v.type(torch.float32)), dim=-2)

        FULL_T = k.shape[-2]

        present = (k, v)

        if self.flash:
            if past_kv is not None:
                is_causal = False
            else:
                is_causal = True
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
        else:
            # manual implementation of attention
            if torch.onnx.is_in_onnx_export():
                # use torch.sqrt here so the scaling stays inside the traced graph
                att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))
            else:
                att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))

        return (y, present)
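
# Shape walk-through for forward2 (illustrative, assuming batch size 1, 16 heads, head size 64):
#   past_kv       -> (2, 16, T_past, 64): one (key, value) pair sliced from the stacked cache
#   past_kv[[0]]  -> (1, 16, T_past, 64): list indexing keeps the leading dim, lining up with (B, nh, T, hs)
#   k after cat   -> (1, 16, T_past + T, 64); present = (k, v) is later re-stacked by GPT.forward2
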
class Block(nn.Module):
    ...
    def forward2(self, x, past_kv, use_cache=False):
        attn_output, prev_kvs = self.attn.forward2(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
        x = x + attn_output
        x = x + self.mlp(self.ln_2(x))
        return (x, prev_kvs)
class GPT(nn.Module):
    ...
    def forward2(self, idx, past_kv):
        device = idx.device
        b, _ = idx.size()

        if 1: # merge_context=True (text model; use the branch below for the coarse model)
        # if 0:
            tok_emb1 = self.transformer.wte(idx[:, -256-256-1:-256-1])
            tok_emb2 = self.transformer.wte(idx[:, -256-1:-1])
            tok_emb3 = self.transformer.wte(idx[:, -1:])
            tok_emb = torch.cat([
                tok_emb1 + tok_emb2, tok_emb3
            ], dim=1)
        else: # for coarse
            tok_emb = self.transformer.wte(idx)

        _, t, _ = tok_emb.shape

        past_length = past_kv[0][0].size(-2)
        position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0) # shape (1, t)
        assert position_ids.shape == (1, t)
        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

        x = self.transformer.drop(tok_emb + pos_emb)

        new_kv = ()
        for i, block in enumerate(self.transformer.h):
            # each block receives its own (key, value) slice of the flat cache
            x, kv = block.forward2(x, past_kv=past_kv[[i*2, i*2+1]], use_cache=True)
            new_kv = new_kv + (kv[0], kv[1])
        new_kv = torch.cat(new_kv, 0)  # re-stack into a single flat cache tensor for the next step

        x = self.transformer.ln_f(x)

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

        return (logits, new_kv)
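
As a rough usage sketch (not part of the patch), the modified GPT.forward2 threads a single stacked cache tensor through the decode loop. The shapes below assume the text model (24 layers, 16 heads, head size 64), and x stands for the token context that generate_text_semantic builds:

    import torch

    past_kv = torch.zeros((48, 16, 0, 64))        # empty cache: (2*n_layer, n_head, 0, head_size)
    logits, past_kv = model.forward2(x, past_kv)  # first step: full context; cache grows to (48, 16, T, 64)
    # ...sample a token from logits, append it to x, then feed only the newest token:
    logits, past_kv = model.forward2(x[:, [-1]], past_kv)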
○ bark/model_fine.py
class NonCausalSelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        )

    def forward(self, x):
        ...
        if self.flash:
            ...
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
↓
class NonCausalSelfAttention(nn.Module):
    def __init__(self, config):
        ...
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        )
        self.flash = False  # force the manual attention path for the ONNX export

    def forward(self, x):
        ...
        if self.flash:
            ...
        else:
            # manual implementation of attention
            if torch.onnx.is_in_onnx_export():
                att = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(k.size(-1)))
            else:
                att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
[text.onnx]
○ bark/generation.py
def generate_text_semantic(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in range(n_tot_steps):
            if use_kv_caching and kv_cache is not None:
                x_input = x[:, [-1]]
            else:
                x_input = x

            logits, kv_cache = model(
                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
            )
↓
def generate_text_semantic(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in range(n_tot_steps):
            # if 1:
            if 0:
                if use_kv_caching and kv_cache is not None:
                    x_input = x[:, [-1]]
                else:
                    x_input = x

                logits, kv_cache = model(
                    x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
                )
            else:
                if use_kv_caching and kv_cache is not None:
                    x_input = x[:, [-1]]
                else:
                    x_input = x
                    # first step: start from an empty cache, (2*n_layer, n_head, 0, head_size) for the text model
                    kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)

                if 1 and n > 0:  # export on the second step, once past_kv is non-empty
                # if 0:
                    print("------>")
                    from torch.autograd import Variable
                    model = model.cpu()
                    xx = (
                        Variable(x_input.cpu()),
                        Variable(torch.from_numpy(kv_cache).cpu()),
                    )
                    model.forward = model.forward2
                    torch.onnx.export(
                        model, xx, 'text.onnx',
                        input_names=["x_input", "past_kv"],
                        output_names=["logits", "kv_cache"],
                        dynamic_axes={'x_input' : [1], 'past_kv' : [2]},
                        verbose=False, opset_version=14
                    )
                    print("<------")
                    1/0

                logits, kv_cache = model.forward2(
                    x_input,
                    past_kv=torch.from_numpy(kv_cache).cuda(),
                )
                kv_cache = kv_cache.cpu().detach().numpy()
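
For reference, a minimal inference sketch (not part of the patch; it assumes onnxruntime and the same preprocessing/sampling as the original loop) that drives the exported text.onnx with the input/output names declared above:

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession("text.onnx")
    kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)  # empty cache, matching the export shape
    x_input = x.cpu().numpy()                                # full (merged) context on the first step
    for n in range(n_tot_steps):
        logits, kv_cache = sess.run(
            ["logits", "kv_cache"],
            {"x_input": x_input, "past_kv": kv_cache},
        )
        # ...sample the next semantic token from logits and append it to x, as in the original loop...
        x_input = x[:, [-1]].cpu().numpy()                   # later steps feed only the newest token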
[coarse.onnx]
○ bark/generation.py
def generate_coarse(
    ...
):
    ...
    with _inference_mode():
        ...
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            ...
            for _ in range(sliding_window_len):
                ...
                if use_kv_caching and kv_cache is not None:
                    x_input = x_in[:, [-1]]
                else:
                    x_input = x_in

                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
↓
def generate_coarse(
    ...
):
    ...
    with _inference_mode():
        ...
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            ...
            for _ in range(sliding_window_len):
                ...
                if use_kv_caching and kv_cache is not None:
                    x_input = x_in[:, [-1]]
                else:
                    x_input = x_in

                # if 1:
                if 0:
                    logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                else:
                    if use_kv_caching and kv_cache is not None:
                        x_input = x_in[:, [-1]]
                    else:
                        x_input = x_in
                        # first step of each window: start from an empty cache, same (48, 16, 0, 64) layout as the text model
                        kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)

                    if 1:
                    # if 0:
                        print("------>")
                        from torch.autograd import Variable
                        model = model.cpu()
                        xx = (
                            Variable(x_input.cpu()),
                            Variable(torch.from_numpy(kv_cache).cpu()),
                        )
                        model.forward = model.forward2
                        torch.onnx.export(
                            model, xx, 'coarse.onnx',
                            input_names=["x_input", "past_kv"],
                            output_names=["logits", "kv_cache"],
                            dynamic_axes={'x_input' : [1], 'past_kv' : [2]},
                            verbose=False, opset_version=14
                        )
                        print("<------")
                        1/0

                    logits, kv_cache = model.forward2(
                        x_input,
                        past_kv=torch.from_numpy(kv_cache).cuda(),
                    )
                    kv_cache = kv_cache.cpu().detach().numpy()
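
coarse.onnx is exported with the same x_input / past_kv / logits / kv_cache signature as text.onnx, so the same onnxruntime loop applies, except that the cache is rebuilt for every sliding window (and GPT.forward2 above presumably has to be switched to its "# for coarse" branch before this export). A rough sketch of one window, under the same assumptions as the text.onnx example:

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession("coarse.onnx")
    for _ in range(n_window_steps):
        # x_in is rebuilt for each window as in the original loop
        kv_cache = np.zeros((48, 16, 0, 64), dtype=np.float32)
        x_input = x_in.cpu().numpy()
        for _ in range(sliding_window_len):
            logits, kv_cache = sess.run(
                ["logits", "kv_cache"],
                {"x_input": x_input, "past_kv": kv_cache},
            )
            # ...sample the next coarse token, append it to x_in, then feed only the newest token:
            x_input = x_in[:, [-1]].cpu().numpy()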
[fine.onnx]
○ bark/generation.py
def generate_fine(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in tqdm.tqdm(range(n_loops), disable=silent):
            ...
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                logits = model(nn, in_buffer)
↓
def generate_fine(
    ...
):
    ...
    with _inference_mode():
        ...
        for n in tqdm.tqdm(range(n_loops), disable=silent):
            ...
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                # if 0:
                if 1:
                    print("------>")
                    from torch.autograd import Variable
                    xx = (
                        Variable(torch.tensor(nn).cpu()),
                        Variable(in_buffer.cpu()),
                    )
                    model = model.cpu()
                    model.eval()
                    torch.onnx.export(
                        model, xx, 'fine.onnx',
                        input_names=["pred_idx", "idx"],
                        output_names=["logits"],
                        verbose=False, opset_version=14
                    )
                    print("<------")
                    1/0
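
A rough sketch (not from the patch) of running the exported fine.onnx for a single codebook prediction, using the pred_idx / idx / logits names declared above. Since no dynamic_axes are given in this export, the inputs are traced with fixed shapes, so idx has to match the export-time buffer shape:

    import numpy as np
    import onnxruntime

    sess = onnxruntime.InferenceSession("fine.onnx")
    logits = sess.run(
        ["logits"],
        {
            "pred_idx": np.array(nn, dtype=np.int64),  # scalar codebook index, matching torch.tensor(nn)
            "idx": in_buffer.cpu().numpy(),            # same in_buffer the original loop feeds the model
        },
    )[0]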
Reference: https://github.com/suno-ai/bark#-update (MIT license)