Status: Open — wangyf8848 opened this issue 3 weeks ago.
Hello, I'm curious why we can obtain `cond_logits` and `uncond_logits` simply by concatenating two identical copies of `x` along the batch dimension and running a single forward pass. Additionally, what does the parameter `cfg_interval` mean?
def decode_one_token(
    model,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    cfg_scale: float,
    cfg_flag: bool,
    sampling_kwargs,
):
    """Run one autoregressive decoding step and sample the next token.

    Why one forward pass yields both conditional and unconditional logits:
    when ``cfg_scale > 1.0`` the current token ``x`` is stacked twice along
    the batch dimension and the doubled batch goes through ``model`` in a
    single call. The output is then split in half; the first half is used as
    the conditional logits and the second half as the unconditional logits.
    Both halves receive identical input tokens here, so any difference
    between them must come from state held inside ``model``.
    NOTE(review): presumably the model's KV cache was prefilled with the
    conditioning prompt in the first half of the batch and a null /
    unconditional prompt in the second half, so each half attends to a
    different context — confirm against the prefill code (not visible here).

    Args:
        model: called as ``model(tokens, cond_idx=None, input_pos=input_pos)``;
            its result is indexed and only the first item (the logits) is used.
        x: token(s) to feed at this step. NOTE(review): presumably shape
            (batch, 1) holding only the conditional half of the batch — confirm.
        input_pos: position of this step; its last dimension must have size 1.
        cfg_scale: classifier-free guidance scale. The doubled forward pass is
            used only when ``cfg_scale > 1.0``.
        cfg_flag: when True, combine the halves with the CFG formula; when
            False, the doubled forward pass still runs but only the
            conditional logits are kept (presumably so the batch keeps
            matching a doubled KV cache — confirm).
        sampling_kwargs: mapping unpacked into ``sample`` as keyword arguments.
            NOTE(review): the ``**`` at the ``sample`` call hints this may
            originally have been ``**sampling_kwargs`` in the signature —
            confirm against callers.

    Returns:
        Whatever ``sample(logits, **sampling_kwargs)`` returns.
    """
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        # Stack two identical copies on dim 0: [first half; second half].
        x_combined = torch.cat([x, x])
        # The model returns a sequence whose first item is the logits.
        logits_combined = model(x_combined, cond_idx=None, input_pos=input_pos)[0]
        # The batch is 2 * len(x), so chunk(2) splits it exactly in half.
        cond_logits, uncond_logits = logits_combined.chunk(2, dim=0)
        if cfg_flag:
            # Classifier-free guidance: push away from the unconditional prediction.
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        logits = model(x, cond_idx=None, input_pos=input_pos)[0]
    return sample(logits, **sampling_kwargs)
def decode_n_tokens(
    model,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    cfg_scale: float,
    cfg_interval: int,
    sampling_kwargs,
):
    """Autoregressively decode ``num_new_tokens`` tokens, one per step.

    Each step calls ``decode_one_token`` on the most recently produced token,
    then advances ``input_pos`` by one.

    Meaning of ``cfg_interval``: it limits how many leading steps use
    classifier-free guidance. With ``cfg_interval <= -1`` guidance is never
    switched off. Otherwise steps ``0 .. cfg_interval`` (inclusive) run with
    ``cfg_flag=True``. Every later step passes ``cfg_flag=False`` to
    ``decode_one_token``, which then keeps only the conditional logits. Once
    switched off, guidance stays off. Whether guidance applies at all still
    depends on ``cfg_scale``.

    Args:
        model: model passed through unchanged to ``decode_one_token``.
        cur_token: token fed at the first step.
        input_pos: position of the current step. Mutated IN PLACE
            (``+= 1`` per step), so the caller's tensor advances too.
        num_new_tokens: number of decoding steps to run.
        cfg_scale: guidance scale forwarded to ``decode_one_token``.
        cfg_interval: last step index that still uses guidance; -1 means all steps.
        sampling_kwargs: mapping forwarded positionally to ``decode_one_token``.

    Returns:
        ``(new_tokens, new_probs)``: two lists with one cloned tensor per step.
    """
    new_tokens, new_probs = [], []
    cfg_flag = True
    for i in range(num_new_tokens):
        # Force the math SDPA backend (flash / mem-efficient disabled).
        # Actually better for Inductor to codegen attention here.
        # NOTE: newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel
        # in favour of torch.nn.attention.sdpa_kernel.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            if cfg_interval > -1 and i > cfg_interval:
                cfg_flag = False
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_flag, sampling_kwargs
            )
            input_pos += 1
            # Copy the results, presumably so later steps cannot overwrite them
            # if the decode step reuses output buffers (e.g. CUDA graphs) — confirm.
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            # Feed the sampled token back as a column for the next step.
            cur_token = next_token.view(-1, 1)

    return new_tokens, new_probs
Hello, I'm curious why we can obtain `cond_logits` and `uncond_logits` simply by concatenating two identical copies of `x` along the batch dimension and running a single forward pass. Additionally, what does the parameter `cfg_interval` mean?
def decode_one_token(
    model,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    cfg_scale: float,
    cfg_flag: bool,
    sampling_kwargs,
):
    """Run one autoregressive decoding step and sample the next token.

    Why one forward pass yields both conditional and unconditional logits:
    when ``cfg_scale > 1.0`` the current token ``x`` is stacked twice along
    the batch dimension and the doubled batch goes through ``model`` in a
    single call. The output is then split in half; the first half is used as
    the conditional logits and the second half as the unconditional logits.
    Both halves receive identical input tokens here, so any difference
    between them must come from state held inside ``model``.
    NOTE(review): presumably the model's KV cache was prefilled with the
    conditioning prompt in the first half of the batch and a null /
    unconditional prompt in the second half, so each half attends to a
    different context — confirm against the prefill code (not visible here).

    Args:
        model: called as ``model(tokens, cond_idx=None, input_pos=input_pos)``;
            its result is indexed and only the first item (the logits) is used.
        x: token(s) to feed at this step. NOTE(review): presumably shape
            (batch, 1) holding only the conditional half of the batch — confirm.
        input_pos: position of this step; its last dimension must have size 1.
        cfg_scale: classifier-free guidance scale. The doubled forward pass is
            used only when ``cfg_scale > 1.0``.
        cfg_flag: when True, combine the halves with the CFG formula; when
            False, the doubled forward pass still runs but only the
            conditional logits are kept (presumably so the batch keeps
            matching a doubled KV cache — confirm).
        sampling_kwargs: mapping unpacked into ``sample`` as keyword arguments.
            NOTE(review): the ``**`` at the ``sample`` call hints this may
            originally have been ``**sampling_kwargs`` in the signature —
            confirm against callers.

    Returns:
        Whatever ``sample(logits, **sampling_kwargs)`` returns.
    """
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        # Stack two identical copies on dim 0: [first half; second half].
        x_combined = torch.cat([x, x])
        # The model returns a sequence whose first item is the logits.
        logits_combined = model(x_combined, cond_idx=None, input_pos=input_pos)[0]
        # The batch is 2 * len(x), so chunk(2) splits it exactly in half.
        cond_logits, uncond_logits = logits_combined.chunk(2, dim=0)
        if cfg_flag:
            # Classifier-free guidance: push away from the unconditional prediction.
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        logits = model(x, cond_idx=None, input_pos=input_pos)[0]
    return sample(logits, **sampling_kwargs)
def decode_n_tokens(
    model,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    cfg_scale: float,
    cfg_interval: int,
    sampling_kwargs,
):
    """Autoregressively decode ``num_new_tokens`` tokens, one per step.

    Each step calls ``decode_one_token`` on the most recently produced token,
    then advances ``input_pos`` by one.

    Meaning of ``cfg_interval``: it limits how many leading steps use
    classifier-free guidance. With ``cfg_interval <= -1`` guidance is never
    switched off. Otherwise steps ``0 .. cfg_interval`` (inclusive) run with
    ``cfg_flag=True``. Every later step passes ``cfg_flag=False`` to
    ``decode_one_token``, which then keeps only the conditional logits. Once
    switched off, guidance stays off. Whether guidance applies at all still
    depends on ``cfg_scale``.

    Args:
        model: model passed through unchanged to ``decode_one_token``.
        cur_token: token fed at the first step.
        input_pos: position of the current step. Mutated IN PLACE
            (``+= 1`` per step), so the caller's tensor advances too.
        num_new_tokens: number of decoding steps to run.
        cfg_scale: guidance scale forwarded to ``decode_one_token``.
        cfg_interval: last step index that still uses guidance; -1 means all steps.
        sampling_kwargs: mapping forwarded positionally to ``decode_one_token``.

    Returns:
        ``(new_tokens, new_probs)``: two lists with one cloned tensor per step.
    """
    new_tokens, new_probs = [], []
    cfg_flag = True
    for i in range(num_new_tokens):
        # Force the math SDPA backend (flash / mem-efficient disabled).
        # Actually better for Inductor to codegen attention here.
        # NOTE: newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel
        # in favour of torch.nn.attention.sdpa_kernel.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            if cfg_interval > -1 and i > cfg_interval:
                cfg_flag = False
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_flag, sampling_kwargs
            )
            input_pos += 1
            # Copy the results, presumably so later steps cannot overwrite them
            # if the decode step reuses output buffers (e.g. CUDA graphs) — confirm.
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            # Feed the sampled token back as a column for the next step.
            cur_token = next_token.view(-1, 1)

    return new_tokens, new_probs