```python
hidden_states = generate_output.hidden_states
last_hidden_states = [item[-1][0] for item in hidden_states]
last_hidden_states = torch.cat(last_hidden_states, dim=0)
seg_hidden_states = get_seg_hidden_states(
    last_hidden_states,
    generate_output.sequences[0][:-1],
    seg_id=omg_llava.seg_token_idx
)
```
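For context, `get_seg_hidden_states` presumably picks out the hidden states at the positions where the model emitted the `[SEG]` token, so they can be passed to the mask decoder. A minimal sketch of that selection (assuming `hidden_states` is aligned one-to-one with the generated token ids; the exact helper in the repo may differ):

```python
import torch

def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    # hidden_states: (seq_len, hidden_dim), one row per generated token
    # output_ids:    (seq_len,) generated token ids aligned with hidden_states
    # Keep only the rows where the generated token is the [SEG] token.
    seg_mask = output_ids == seg_id
    return hidden_states[seg_mask]
```

With this selection, the number of returned rows equals the number of `[SEG]` tokens in the generated sequence.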
I notice that your code does not account for the beam-search process during inference. Will the quality of the generated text degrade without beam search?