jzhang38 / EasyContext

Memory optimization and training recipes to extrapolate language models' context length to 1 million tokens, with minimal hardware.
Apache License 2.0

Not the real auto-regressive decoding mode? #8

Open microhu opened 2 months ago

microhu commented 2 months ago

Dear author,

In the `eval_forward` function below, it seems this is not true autoregressive decoding. Since you concatenate the input and `answer_ids` together to form the new `input_ids`, the model decodes in teacher-forcing mode, not real autoregressive decoding.

Am I correct?


```python
import torch

# `prepare_seq_parallel_inputs` comes from the EasyContext codebase; it shards
# the full sequence across processes for sequence-parallel (ring) attention.
from easy_context import prepare_seq_parallel_inputs


def eval_forward(accelerator, model, input_ids, pad_id, answer_ids):
    # first append labels to input_ids
    prompt_length = input_ids.shape[1]
    labels_length = answer_ids.shape[1]
    input_ids = torch.cat([input_ids, answer_ids], dim=1)
    # second pad input_ids to a multiple of accelerator.num_processes * 2
    # (zigzag ring attention assigns two chunks of the sequence to each rank)
    pad_tensor = torch.tensor(
        [pad_id]
        * (
            (accelerator.num_processes * 2)
            - input_ids.shape[1] % (accelerator.num_processes * 2)
        )
    ).unsqueeze(0)
    input_ids = torch.cat([input_ids, pad_tensor], dim=1)
    position_ids = (
        torch.arange(input_ids.shape[1]).unsqueeze(0).expand(input_ids.shape[0], -1)
    )
    prepared = prepare_seq_parallel_inputs(
        "zigzag_ring_attn",
        input_ids,
        position_ids,
        None,
        accelerator.process_index,
        accelerator.num_processes,
        accelerator.device,
    )
    local_input_ids = prepared["local_input_ids"]
    local_position_ids = prepared["local_position_ids"]
    # one teacher-forced forward pass over prompt + answer, with no KV cache
    with torch.inference_mode():
        logits = model(
            local_input_ids,
            position_ids=local_position_ids,
            use_cache=False,
        ).logits
        pred = logits.argmax(dim=-1)
```
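
The quoted snippet ends at `pred`; in the full eval script the per-rank predictions still have to be gathered and re-ordered before scoring. A minimal sketch of the final check, assuming `pred` has already been re-assembled into the original token order (the slicing below is hand-written for illustration, not quoted from the repo):

```python
# Logits at position t predict token t + 1, so the answer span is predicted by
# positions [prompt_length - 1, prompt_length + labels_length - 1).
pred_answer = pred[:, prompt_length - 1 : prompt_length + labels_length - 1]
correct = bool((pred_answer == answer_ids.to(pred_answer.device)).all())
```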
jzhang38 commented 2 months ago

I used a PPL-based eval strategy to avoid saving the KV cache. An answer is only counted as correct when every answer token receives the highest logit in the teacher-forced forward pass. This produces the same result as generation-based eval with greedy decoding.
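
To see why the two judgments coincide, here is a self-contained sketch (illustrative only, not from the repo; it assumes a HuggingFace-style model whose output has a `.logits` field). Greedy decoding feeds back its own argmax at every step, so as long as each argmax matches the gold answer token, the greedy and teacher-forced passes see identical inputs and produce identical logits; both checks fail at the same first mismatch and succeed together otherwise.

```python
import torch


@torch.inference_mode()
def greedy_matches_answer(model, prompt_ids, answer_ids):
    """Decode step by step with greedy argmax; True iff it reproduces answer_ids."""
    seq = prompt_ids
    for t in range(answer_ids.shape[1]):
        next_tok = model(seq).logits[:, -1].argmax(-1, keepdim=True)
        if next_tok.item() != answer_ids[0, t].item():
            return False  # greedy decoding diverged from the gold answer
        seq = torch.cat([seq, next_tok], dim=1)
    return True


@torch.inference_mode()
def teacher_forced_matches_answer(model, prompt_ids, answer_ids):
    """One forward pass over prompt + answer; True iff every answer token wins argmax."""
    full = torch.cat([prompt_ids, answer_ids], dim=1)
    logits = model(full).logits
    p = prompt_ids.shape[1]
    # Logits at position t predict token t + 1.
    pred = logits[:, p - 1 : full.shape[1] - 1].argmax(-1)
    return bool((pred == answer_ids).all())
```

Both functions return the same boolean for any model and (prompt, answer) pair, but the teacher-forced version needs only a single forward pass and no KV cache, which is what makes it attractive at million-token context lengths.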